From 3b4dc2bd09da1458aaa610167a1ae0137a173144 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 22 Apr 2026 16:53:04 +0000 Subject: [PATCH 1/2] feat: TorchTRT Annotation Layer for Cuda generated kernels --- .github/workflows/build-test-linux-x86_64.yml | 6 + docsrc/py_api/annotation.rst | 156 ++++ docsrc/py_api/index.rst | 1 + .../auto_cuda_kernel_plugin_annotation.py | 118 +++ .../manual_cuda_kernel_plugin_annotation.py | 160 +++++ py/torch_tensorrt/__init__.py | 3 + py/torch_tensorrt/_features.py | 8 +- py/torch_tensorrt/annotation/__init__.py | 115 +++ .../annotation/_custom_plugin/__init__.py | 305 ++++++++ .../annotation/_custom_plugin/_descriptor.py | 240 +++++++ .../annotation/_custom_plugin/_nvrtc.py | 64 ++ .../annotation/_kernel_plugin.py | 676 ++++++++++++++++++ py/torch_tensorrt/annotation/_kernel_spec.py | 165 +++++ py/torch_tensorrt/annotation/_specs.py | 30 + pyproject.toml | 4 + setup.py | 4 + tests/py/annotation/__init__.py | 0 tests/py/annotation/conftest.py | 144 ++++ .../test_auto_cuda_kernel_plugin.py | 372 ++++++++++ .../test_manual_cuda_kernel_plugin.py | 302 ++++++++ 20 files changed, 2868 insertions(+), 5 deletions(-) create mode 100644 docsrc/py_api/annotation.rst create mode 100644 examples/dynamo/auto_cuda_kernel_plugin_annotation.py create mode 100644 examples/dynamo/manual_cuda_kernel_plugin_annotation.py create mode 100644 py/torch_tensorrt/annotation/__init__.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/__init__.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py create mode 100644 py/torch_tensorrt/annotation/_kernel_plugin.py create mode 100644 py/torch_tensorrt/annotation/_kernel_spec.py create mode 100644 py/torch_tensorrt/annotation/_specs.py create mode 100644 tests/py/annotation/__init__.py create mode 100644 tests/py/annotation/conftest.py create mode 100644 tests/py/annotation/test_auto_cuda_kernel_plugin.py create mode 100644 tests/py/annotation/test_manual_cuda_kernel_plugin.py diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index a437c284c0..b4fc8b2d81 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -459,6 +459,12 @@ jobs: python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py popd + pushd . + # cuda-python is an optional runtime dep for the torch_tensorrt.annotation QDP layer. + python -m pip install cuda-python + cd tests/py/annotation + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_annotation_test_results.xml . + popd L2-torchscript-tests: name: ${{ matrix.display-name }} diff --git a/docsrc/py_api/annotation.rst b/docsrc/py_api/annotation.rst new file mode 100644 index 0000000000..8ab16642c9 --- /dev/null +++ b/docsrc/py_api/annotation.rst @@ -0,0 +1,156 @@ +.. _torch_tensorrt_annotation_py: + +torch_tensorrt.annotation +========================== + +.. currentmodule:: torch_tensorrt.annotation + +.. automodule:: torch_tensorrt.annotation + +.. note:: + + This module is **experimental**. It requires ``cuda-python`` at runtime + and TensorRT ``>=10.7.0`` (and not ``10.14.x``) for Quick Deployable + Plugin (QDP) support. 
Install ``cuda-python`` with ``pip install + cuda-python``. + +Overview +-------- + +The ``annotation`` module registers NVRTC-compiled CUDA C++ kernels as +TensorRT Quick Deployable Plugins with full Ahead-of-Time (AOT) +compilation support. It offers two entry points that trade +declarativeness for flexibility — start with :func:`auto_cuda_kernel_plugin` and +drop down to :func:`manual_cuda_kernel_plugin` only when your kernel falls outside +the declarative DSL: + +.. list-table:: + :header-rows: 1 + :widths: 28 36 36 + + * - Entry point + - What you provide + - What you get for free + * - :func:`auto_cuda_kernel_plugin` + - A :class:`KernelSpec` dataclass (source, inputs, outputs, extras, + geometry) + - Meta / eager / AOT functions and the PyTorch schema — all derived + * - :func:`manual_cuda_kernel_plugin` + - ``aot_fn`` + ``eager_fn`` + a meta function decorated with the + one-shot decorator + - PyTorch op + TRT plugin + converter, registered together + +For unary-pointwise kernels, :func:`pointwise_aot` and +:func:`pointwise_eager` produce the two callables so users can plug them +directly into :func:`manual_cuda_kernel_plugin`. + +Declarative entry point +----------------------- + +.. autofunction:: auto_cuda_kernel_plugin + +KernelSpec DSL +^^^^^^^^^^^^^^ + +.. autoclass:: KernelSpec + :members: + +.. autoclass:: InputDecl + :members: + +.. autoclass:: ScalarInput + :members: + +.. autoclass:: OutputDecl + :members: + +Shape relations +""""""""""""""" + +.. autoclass:: SameAs + :members: + +.. autoclass:: ReduceDims + :members: + +Extra scalar args +""""""""""""""""" + +Extras are passed to the kernel between the input and output pointer +lists in :class:`KernelSpec` order. + +.. autoclass:: Numel + :members: + +.. autoclass:: DimSize + :members: + +Launch geometry +""""""""""""""" + +.. autoclass:: Elementwise + :members: + +.. autoclass:: Reduction + :members: + +.. autoclass:: Custom + :members: + +One-shot hand-written entry point +--------------------------------- + +.. autofunction:: manual_cuda_kernel_plugin + +Lower-level building blocks +--------------------------- + +.. autofunction:: cuda_python + +.. autofunction:: custom_plugin + +Spec class +^^^^^^^^^^ + +.. autoclass:: CudaPythonSpec + :members: + +Pointwise helpers +----------------- + +.. autofunction:: pointwise_aot + +.. autofunction:: pointwise_eager + +Kernel signature convention +--------------------------- + +All entry points assume the ``__global__`` kernel takes its arguments in +the fixed order:: + + (input_ptrs..., extras..., output_ptrs...) + +Pointers are ``void*`` cast to the appropriate element type. Extras +follow the order declared in :attr:`KernelSpec.extras` for the +declarative path, or the order your ``aot_fn`` builds for the manual +path. + +Error behavior +-------------- + +:func:`auto_cuda_kernel_plugin` validates the :class:`KernelSpec` at decorator +time and raises :class:`ValueError` for the common authoring mistakes: + +- Empty or duplicate-named ``inputs`` / ``outputs``. +- ``ReduceDims(input_idx=N)`` or ``SameAs(input_idx=N)`` where ``N`` is + out of range. +- ``Numel`` / ``DimSize`` referencing a name that is not an input. +- ``dtype_from`` pointing at an unknown input. +- ``Elementwise(layout='flat')`` with a multi-dimensional block tuple. +- Invalid block sizes, ``block_size`` in :class:`Reduction`, or a + non-callable :attr:`Custom.fn`. 
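+
+For instance, duplicate input names fail fast at decorator time, before any
+NVRTC compilation runs (a minimal sketch; ``CU_SRC`` and the op name are
+illustrative)::
+
+    import torch_tensorrt.annotation as tta
+
+    bad_spec = tta.KernelSpec(
+        kernel_source=CU_SRC,
+        kernel_name="my_kernel",
+        inputs=[tta.InputDecl("x"), tta.InputDecl("x")],  # duplicate name
+        outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))],
+        extras=[tta.Numel("x")],
+        geometry=tta.Elementwise(block=(256,), layout="flat"),
+    )
+    tta.auto_cuda_kernel_plugin("myns::bad", bad_spec)  # raises ValueError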
+ +Shape-dependent errors — for example +``Elementwise(layout='nd', block=(16, 16))`` invoked against a 1-D +output — are raised at launch time in a clear ``ValueError`` because +the offending ranks are only known when concrete tensors arrive. diff --git a/docsrc/py_api/index.rst b/docsrc/py_api/index.rst index 689e754637..5e70a89038 100644 --- a/docsrc/py_api/index.rst +++ b/docsrc/py_api/index.rst @@ -13,6 +13,7 @@ Core dynamo logging runtime + annotation ../cli/torchtrtc ../indices/supported_ops diff --git a/examples/dynamo/auto_cuda_kernel_plugin_annotation.py b/examples/dynamo/auto_cuda_kernel_plugin_annotation.py new file mode 100644 index 0000000000..a3f4dd99d5 --- /dev/null +++ b/examples/dynamo/auto_cuda_kernel_plugin_annotation.py @@ -0,0 +1,118 @@ +""" +.. _auto_cuda_kernel_plugin_annotation: + +Declarative Custom Kernel via ``torch_tensorrt.annotation.auto_cuda_kernel_plugin`` +==================================================================================== + +``auto_cuda_kernel_plugin`` is the recommended entry point for exposing a +hand-written CUDA C++ kernel to both PyTorch eager and the Torch-TensorRT +compile path. + +You describe the kernel with a :class:`KernelSpec` dataclass (inputs, outputs, +extras, launch geometry) and the framework derives the meta function, the +eager CUDA launch, the TensorRT AOT implementation, and the PyTorch op schema +— no hand-written ``aot_fn`` / ``eager_fn`` required. + +Use ``auto_cuda_kernel_plugin`` whenever your kernel follows the standard +calling convention ``(input_ptrs..., extras..., output_ptrs...)`` and fits one +of the built-in geometries: ``Elementwise(layout="flat" | "nd")`` or +``Reduction(reduce_dims=...)``. + +For shape-changing kernels or anything outside that envelope, drop down to +:func:`torch_tensorrt.annotation.manual_cuda_kernel_plugin` (see +``manual_cuda_kernel_plugin_annotation.py``). +""" + +import sys + +import torch +import torch_tensorrt + +if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + print( + "[auto_cuda_kernel_plugin_annotation] Skipping example: " + "torch_tensorrt.annotation requires TensorRT QDP plugin support." + ) + sys.exit(0) + +try: + import tensorrt.plugin # noqa: F401 +except ImportError: + print( + "[auto_cuda_kernel_plugin_annotation] Skipping example: " + "tensorrt.plugin unavailable." + ) + sys.exit(0) + +try: + import cuda.core # noqa: F401 +except ImportError: + try: + import cuda.core.experimental # noqa: F401 + except ImportError: + print( + "[auto_cuda_kernel_plugin_annotation] Skipping example: cuda-python " + "is not installed. Install with `pip install cuda-python`." + ) + sys.exit(0) + +import torch_tensorrt.annotation as tta + + +# Calling convention expected by auto_cuda_kernel_plugin: +# (input_ptrs..., extras..., output_ptrs...) + +CU_SIGMOID = """ +extern "C" __global__ void my_sigmoid( + const float* __restrict__ x, int n, float* __restrict__ y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = 1.0f / (1.0f + __expf(-x[i])); +} +""" + + +# SameAs(0) output has the same shape and dtype as input 0. +# Numel("x") pass x.numel() to the kernel as an int extra. +# Elementwise(flat) 1-D launch over the flattened output; any input rank works. 
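+#
+# With this spec the derived op schema is "(Tensor x) -> Tensor", and the kernel
+# is invoked as (x_ptr, n, y_ptr): inputs first, extras next, outputs last.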
+ +tta.auto_cuda_kernel_plugin( + "ann_ex::sigmoid", + tta.KernelSpec( + kernel_source=CU_SIGMOID, + kernel_name="my_sigmoid", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.Numel("x")], + geometry=tta.Elementwise(block=(256,), layout="flat"), + ), + supports_dynamic_shapes=True, +) + + +class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.ann_ex.sigmoid(x) + + +if __name__ == "__main__": + x = torch.randn(4, 128, device="cuda", dtype=torch.float32) + ref = torch.sigmoid(x) + + model = Model().cuda().eval() + eager_out = model(x) + print("Eager result matches torch.sigmoid:", torch.allclose(eager_out, ref, atol=1e-4)) + + print("Compiling with Torch-TensorRT...") + trt_model = torch_tensorrt.compile( + model, + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + + with torch.no_grad(): + for _ in range(5): + out = trt_model(x) + assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2), "Mismatch!" + + print("TRT inference successful - results match torch.sigmoid") diff --git a/examples/dynamo/manual_cuda_kernel_plugin_annotation.py b/examples/dynamo/manual_cuda_kernel_plugin_annotation.py new file mode 100644 index 0000000000..a14e8da1cf --- /dev/null +++ b/examples/dynamo/manual_cuda_kernel_plugin_annotation.py @@ -0,0 +1,160 @@ +""" +.. _manual_cuda_kernel_plugin_annotation: + +Hand-Written Custom Plugin via ``torch_tensorrt.annotation.manual_cuda_kernel_plugin`` +======================================================================================= + +This example demonstrates a shape-changing (non-pointwise) CUDA kernel that +duplicates each input element into two output elements: + y[2*i] = x[i], y[2*i + 1] = x[i] + +Because the output size (``2 * n``) does not match any of the geometries built +into the declarative :func:`torch_tensorrt.annotation.auto_cuda_kernel_plugin` +DSL, we drop down one layer to +:func:`torch_tensorrt.annotation.manual_cuda_kernel_plugin` and hand-write +``eager_fn`` / ``aot_fn`` directly. + +For kernels that *do* fit the DSL (pointwise, ND-grid, reduction) start from +the simpler declarative path shown in +``auto_cuda_kernel_plugin_annotation.py``. +""" + +import sys + +import torch +import torch_tensorrt + +if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin: + print( + "[manual_cuda_kernel_plugin_annotation] Skipping example: " + "torch_tensorrt.annotation requires TensorRT QDP plugin support." + ) + sys.exit(0) + +try: + import tensorrt.plugin as trtp +except ImportError: + print("[manual_cuda_kernel_plugin_annotation] Skipping example: tensorrt.plugin unavailable.") + sys.exit(0) + +try: + from cuda.core import Device as _Device + from cuda.core import LaunchConfig as _LaunchConfig + from cuda.core import Program as _Program + from cuda.core import ProgramOptions as _ProgramOptions + from cuda.core import launch as _cuda_launch +except ImportError: + try: + from cuda.core.experimental import Device as _Device + from cuda.core.experimental import LaunchConfig as _LaunchConfig + from cuda.core.experimental import Program as _Program + from cuda.core.experimental import ProgramOptions as _ProgramOptions + from cuda.core.experimental import launch as _cuda_launch + except ImportError: + print( + "[manual_cuda_kernel_plugin_annotation] Skipping example: cuda-python is not " + "installed. Install with `pip install cuda-python` to run this example." 
+ ) + sys.exit(0) + +import torch_tensorrt.annotation as tta + + +CU_REPEAT2 = """ +extern "C" __global__ void repeat2_kernel( + const float* __restrict__ x, const int n, float* __restrict__ y) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + const float v = x[i]; + y[2 * i] = v; + y[2 * i + 1] = v; + } +} +""" + +_device = _Device() +_device.set_current() +_opts = _ProgramOptions( + std="c++17", arch=f"sm_{_device.arch}", include_path=["/usr/local/cuda/include"] +) +_program = _Program(CU_REPEAT2, code_type="c++", options=_opts) +_module = _program.compile("ptx", name_expressions=("repeat2_kernel",)) +_kernel = _module.get_kernel("repeat2_kernel") + + +class _PTStream: + def __cuda_stream__(self): + return (0, torch.cuda.current_stream().cuda_stream) + + +def _eager_repeat2(x: torch.Tensor) -> torch.Tensor: + if x.dtype != torch.float32: + raise ValueError("This example expects float32 input") + flat = x.contiguous().view(-1) + n = int(flat.numel()) + y = torch.empty((n * 2,), device=x.device, dtype=x.dtype) + block = 256 + grid = max(1, (n + block - 1) // block) + stream = _device.create_stream(_PTStream()) + _cuda_launch( + stream, + _LaunchConfig(grid=(grid,), block=(block,)), + _kernel, + flat.data_ptr(), + n, + y.data_ptr(), + ) + return y + + +def _aot_repeat2(inputs, outputs, tactic): + n = inputs[0].shape_expr.numel() + params = trtp.KernelLaunchParams() + params.grid_x = trtp.cdiv(n, 256) + params.block_x = 256 + params.shared_mem = 0 + extra = trtp.SymIntExprs(1) + extra[0] = trtp.SymInt32(n) + return params, extra + + +@tta.manual_cuda_kernel_plugin( + op_name="ann_ex::repeat2", + kernel_source=CU_REPEAT2, + kernel_name="repeat2_kernel", + eager_fn=_eager_repeat2, + aot_fn=_aot_repeat2, + supports_dynamic_shapes=True, +) +def _repeat2_meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty((x.numel() * 2,), device=x.device, dtype=x.dtype) + + +class Repeat2Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.ann_ex.repeat2(x) + + +if __name__ == "__main__": + x = torch.randn(1024, device="cuda", dtype=torch.float32) + ref = torch.repeat_interleave(x, 2, dim=0) + + model = Repeat2Model().cuda().eval() + eager_out = model(x) + print("Eager result matches repeat_interleave:", torch.allclose(eager_out, ref, atol=1e-4)) + + print("Compiling with Torch-TensorRT...") + with torch_tensorrt.logging.debug(): + trt_model = torch_tensorrt.compile( + model, + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + + with torch.no_grad(): + for _ in range(5): + out = trt_model(x) + assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2), "Mismatch!" 
+ + print("TRT inference successful - results match repeat_interleave") diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index d127f42690..07619bace8 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -99,6 +99,9 @@ def _register_with_torch() -> None: from torch_tensorrt.dynamo import backend # noqa: F401 from torch_tensorrt import dynamo # noqa: F401 +if ENABLED_FEATURES.qdp_plugin: + from torch_tensorrt import annotation # noqa: F401 + from torch_tensorrt._compile import * # noqa: F403 from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import ( MutableTorchTensorRTModule, diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 7836d63c24..b3aedd109e 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -43,7 +43,7 @@ linked_file_full_path = os.path.join(trtorch_dir, linked_file) linked_file_runtime_full_path = os.path.join(trtorch_dir, linked_file_runtime) -_TENSORRT_RTX = tensorrt._package_name == "tensorrt_rtx" +_TENSORRT_RTX = getattr(tensorrt, "_package_name", "") == "tensorrt_rtx" _TS_FE_AVAIL = os.path.isfile(linked_file_full_path) _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path) _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") @@ -57,10 +57,8 @@ elif importlib.util.find_spec("tensorrt.plugin") and importlib.util.find_spec( "tensorrt.plugin._lib" ): - # there is a bug in tensorrt 10.14.* and 10.15.* that causes the plugin to not work, disable it for now - if tensorrt.__version__.startswith("10.15.") or tensorrt.__version__.startswith( - "10.14." - ): + # TensorRT 10.14.* has a known bug that breaks QDP plugins; 10.15.+ works. + if tensorrt.__version__.startswith("10.14."): _QDP_PLUGIN_AVAIL = False else: _QDP_PLUGIN_AVAIL = True diff --git a/py/torch_tensorrt/annotation/__init__.py b/py/torch_tensorrt/annotation/__init__.py new file mode 100644 index 0000000000..31cd62a554 --- /dev/null +++ b/py/torch_tensorrt/annotation/__init__.py @@ -0,0 +1,115 @@ +""" +torch_tensorrt.annotation (experimental) +========================================== +High-level decorators for registering custom CUDA C++ kernels — compiled at +runtime with NVRTC via **cuda-python** — as TensorRT Quick Deployable Plugins +(QDP) with full AOT support. + +The module offers two registration paths. Pick ``auto_cuda_kernel_plugin`` first; drop +down to ``manual_cuda_kernel_plugin`` only when your kernel doesn't fit the Elementwise / +Reduction conventions. + +``auto_cuda_kernel_plugin`` *(recommended starting point)* + Fully declarative. Describe inputs, outputs, extras, and launch geometry + via :class:`KernelSpec` dataclasses; the meta / eager / aot functions and + the PyTorch schema are all derived for you. Covers pointwise, ND-grid, + and reduction kernels out of the box. + +``manual_cuda_kernel_plugin`` + One-shot decorator when you want to hand-write ``aot_fn`` and ``eager_fn``. + Use this for shape-changing kernels, multi-output kernels, or anything + outside the declarative DSL. Registers the PyTorch op and the TRT plugin + in a single step. + +``custom_plugin`` / ``cuda_python`` + Lower-level building blocks used internally by ``manual_cuda_kernel_plugin``. Use + ``cuda_python`` to construct a reusable :class:`CudaPythonSpec`, then pass + it to ``custom_plugin`` to register. 
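+
+    A sketch of that two-step flow (kernel source and names are
+    illustrative; ``_aot`` / ``_eager`` are hand-written as in the
+    ``custom_plugin`` docstring)::
+
+        spec = cuda_python(cu_code, "my_kernel", aot_fn=_aot, eager_fn=_eager)
+
+        @custom_plugin("myns::my_op", spec, supports_dynamic_shapes=True)
+        def _(x: torch.Tensor) -> torch.Tensor:
+            return torch.empty_like(x)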
+ +``pointwise_aot`` / ``pointwise_eager`` + Helper factories that produce the two callables for unary pointwise + kernels, reducing cuda-python boilerplate when using ``manual_cuda_kernel_plugin`` + directly. + +Minimal example — ``auto_cuda_kernel_plugin`` (derives meta/eager/aot/schema):: + + import torch, torch_tensorrt + import torch_tensorrt.annotation as tta + + cu_code = \"\"\" + extern "C" __global__ void my_relu(const float* x, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f; + } + \"\"\" + + tta.auto_cuda_kernel_plugin( + "myns::relu", + tta.KernelSpec( + kernel_source=cu_code, + kernel_name="my_relu", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.Numel("x")], + geometry=tta.Elementwise(block=(256,), layout="flat"), + ), + supports_dynamic_shapes=True, + ) + + class M(torch.nn.Module): + def forward(self, x): return torch.ops.myns.relu(x) + + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=[torch.randn(1024, device="cuda")], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + +For kernels outside the Elementwise / Reduction shape conventions, drop down +to :func:`manual_cuda_kernel_plugin` and supply ``aot_fn`` / ``eager_fn`` directly. +""" + +from torch_tensorrt.annotation._specs import CudaPythonSpec +from torch_tensorrt.annotation._custom_plugin import ( + manual_cuda_kernel_plugin, + cuda_python, + custom_plugin, + pointwise_aot, + pointwise_eager, +) +from torch_tensorrt.annotation._kernel_spec import ( + Custom, + DimSize, + Elementwise, + InputDecl, + KernelSpec, + Numel, + OutputDecl, + ReduceDims, + Reduction, + SameAs, + ScalarInput, +) +from torch_tensorrt.annotation._kernel_plugin import auto_cuda_kernel_plugin + +__all__ = [ + "CudaPythonSpec", + "Custom", + "DimSize", + "Elementwise", + "InputDecl", + "KernelSpec", + "Numel", + "OutputDecl", + "ReduceDims", + "Reduction", + "SameAs", + "ScalarInput", + "manual_cuda_kernel_plugin", + "cuda_python", + "custom_plugin", + "auto_cuda_kernel_plugin", + "pointwise_aot", + "pointwise_eager", +] diff --git a/py/torch_tensorrt/annotation/_custom_plugin/__init__.py b/py/torch_tensorrt/annotation/_custom_plugin/__init__.py new file mode 100644 index 0000000000..ad0104152f --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/__init__.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import logging +from typing import Callable, List, Optional + +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority +from torch_tensorrt.annotation._specs import CudaPythonSpec, _default_cuda_include_paths + +_LOGGER = logging.getLogger(__name__) + + +def cuda_python( + kernel_source: str, + kernel_name: str, + aot_fn: Optional[Callable] = None, + eager_fn: Optional[Callable] = None, + include_paths: Optional[List[str]] = None, + compile_std: str = "c++17", + arch_override: Optional[str] = None, +) -> CudaPythonSpec: + """Create a :class:`CudaPythonSpec` for a CUDA C++ kernel compiled with NVRTC. + + Args: + kernel_source: Raw CUDA C++ source string containing the ``__global__`` kernel. + kernel_name: Name of the ``extern "C" __global__`` function to invoke. + aot_fn: Callable ``(inputs, outputs, tactic) -> (KernelLaunchParams, SymExprs | None)`` + that returns launch parameters for the TensorRT AOT implementation. + May be set later by assigning to ``spec.aot_fn``. 
+ eager_fn: Optional callable used as the CUDA device implementation of the + PyTorch custom op. Required when ``register_torch_op=True`` in + :func:`custom_plugin`. Signature must match the op schema. + include_paths: Extra ``#include`` search paths passed to NVRTC. + Defaults to ``$CUDA_HOME/include`` (or ``$CUDA_PATH/include``), + falling back to ``["/usr/local/cuda/include"]``. + compile_std: C++ standard flag forwarded to NVRTC (default ``"c++17"``). + arch_override: Override the GPU architecture string (e.g. ``"sm_86"``). + When ``None`` the current device's arch is used. + + Returns: + A :class:`CudaPythonSpec` instance ready for use with :func:`custom_plugin`. + """ + if not ENABLED_FEATURES.qdp_plugin: + raise RuntimeError( + "TensorRT QDP plugins are not available. " + "Requires TensorRT >= 10.7.0 (and not 10.14.x)." + ) + return CudaPythonSpec( + kernel_source=kernel_source, + kernel_name=kernel_name, + aot_fn=aot_fn, + eager_fn=eager_fn, + include_paths=include_paths if include_paths is not None else _default_cuda_include_paths(), + compile_std=compile_std, + arch_override=arch_override, + ) + + +def custom_plugin( + op_name: str, + spec: CudaPythonSpec, + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, + priority: ConverterPriority = ConverterPriority.STANDARD, + capability_validator: Optional[Callable] = None, + schema: Optional[str] = None, +) -> Callable: + """Decorator that registers a CUDA kernel as a TensorRT QDP plugin. + + The decorated function acts as the **meta / fake implementation** (shape and + dtype inference for TorchDynamo tracing). It must mirror the op's signature + with proper type annotations and return ``torch.empty_*`` tensors of the + correct shape. + + ``spec.eager_fn`` is registered as the CUDA device implementation of the + generated PyTorch custom op, so it must accept the same positional arguments + and return the same outputs as the meta function. + + ``spec.aot_fn`` is registered as the TensorRT AOT implementation and must + have the signature:: + + def aot_fn(inputs: list[trtp.TensorDesc], + outputs: tuple[trtp.TensorDesc], + tactic: int) -> tuple[trtp.KernelLaunchParams, trtp.SymExprs | None]: + + Example:: + + cu_code = \"\"\" + extern "C" __global__ void my_sigmoid(const float* x, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = 1.0f / (1.0f + __expf(-x[i])); + } + \"\"\" + + def _eager(x: torch.Tensor) -> torch.Tensor: + ... # cuda-python launch + + def _aot(inputs, outputs, tactic): + import tensorrt.plugin as trtp + N = inputs[0].shape_expr.numel() + p = trtp.KernelLaunchParams() + p.grid_x, p.block_x, p.shared_mem = trtp.cdiv(N, 256), 256, 0 + extra = trtp.SymIntExprs(1) + extra[0] = trtp.SymInt32(N) + return p, extra + + spec = tta.cuda_python(cu_code, "my_sigmoid", aot_fn=_aot, eager_fn=_eager) + + @tta.custom_plugin("myns::sigmoid", spec, supports_dynamic_shapes=True) + def _(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + Args: + op_name: Plugin id in ``"namespace::name"`` format. + spec: :class:`CudaPythonSpec` with ``aot_fn`` and ``eager_fn`` populated. + supports_dynamic_shapes: Pass ``True`` if the kernel handles dynamic shapes. + requires_output_allocator: Set ``True`` for data-dependent output shapes. + priority: Converter priority in the registry. + capability_validator: Optional ``(Node, CompilationSettings) -> bool`` guard. 
+ schema: Optional explicit PyTorch op schema suffix in Torch schema syntax + (for example, ``"(Tensor x) -> Tensor"``). When omitted, the schema + is inferred from the decorated meta function's type hints. + """ + from torch_tensorrt.annotation._custom_plugin._descriptor import ( + register_cuda_python_plugin, + ) + + def decorator(meta_fn: Callable) -> Callable: + register_cuda_python_plugin( + op_name=op_name, + spec=spec, + meta_fn=meta_fn, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + priority=priority, + capability_validator=capability_validator, + register_torch_op=True, + schema=schema, + ) + return meta_fn + + return decorator + + +def pointwise_aot( + block_size: int = 256, + input_index: int = 0, +) -> Callable: + """Create a common AOT launch-config function for pointwise kernels. + + The generated function computes ``N = inputs[input_index].shape_expr.numel()`` + and returns launch parameters using 1D blocks with one symbolic extra arg ``N``. + + Args: + block_size: CUDA block size used for ``block_x``. + input_index: TensorDesc index from which ``N`` is derived. + """ + if block_size <= 0: + raise ValueError(f"block_size must be > 0, got {block_size}") + if input_index < 0: + raise ValueError(f"input_index must be >= 0, got {input_index}") + + def _aot(inputs, outputs, tactic): + import tensorrt.plugin as trtp + + if input_index >= len(inputs): + raise IndexError( + f"input_index={input_index} is out of range for inputs of length {len(inputs)}" + ) + n = inputs[input_index].shape_expr.numel() + params = trtp.KernelLaunchParams() + params.grid_x = trtp.cdiv(n, block_size) + params.block_x = block_size + params.shared_mem = 0 + extra = trtp.SymIntExprs(1) + extra[0] = trtp.SymInt32(n) + return params, extra + + return _aot + + +def pointwise_eager( + kernel_source: str, + kernel_name: str, + include_paths: Optional[List[str]] = None, + compile_std: str = "c++17", + arch_override: Optional[str] = None, + block_size: int = 256, +) -> Callable: + """Create an eager CUDA implementation for unary pointwise kernels. + + The generated function assumes kernel signature: + ``kernel(const T* x, int n, T* y)`` + where ``n`` is the flattened element count. + + Args: + kernel_source: CUDA C++ source containing ``kernel_name``. + kernel_name: Name of the ``extern "C" __global__`` kernel function. + include_paths: Optional NVRTC include paths. + compile_std: C++ standard passed to NVRTC. + arch_override: Optional explicit arch string (for example ``"sm_90"``). + block_size: CUDA block size for launch. 
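+
+    Example (a sketch; ``CU_GELU`` and the op name are illustrative)::
+
+        eager = pointwise_eager(CU_GELU, "my_gelu")
+        aot = pointwise_aot(block_size=256)
+
+        @manual_cuda_kernel_plugin(
+            op_name="myns::gelu",
+            kernel_source=CU_GELU,
+            kernel_name="my_gelu",
+            eager_fn=eager,
+            aot_fn=aot,
+            supports_dynamic_shapes=True,
+        )
+        def _meta(x: torch.Tensor) -> torch.Tensor:
+            return torch.empty_like(x)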
+ """ + if block_size <= 0: + raise ValueError(f"block_size must be > 0, got {block_size}") + + resolved_include_paths = ( + include_paths if include_paths is not None else _default_cuda_include_paths() + ) + runtime_cache = {} + + def _ensure_compiled() -> None: + if runtime_cache: + return + from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx + + _ptx, device, kernel = compile_to_ptx( + kernel_source=kernel_source, + kernel_name=kernel_name, + include_paths=resolved_include_paths, + compile_std=compile_std, + arch_override=arch_override, + ) + try: + from cuda.core import LaunchConfig, launch + except ImportError: + from cuda.core.experimental import LaunchConfig, launch + + runtime_cache["device"] = device + runtime_cache["kernel"] = kernel + runtime_cache["launch"] = launch + runtime_cache["LaunchConfig"] = LaunchConfig + + def _eager(x): + import torch + + _ensure_compiled() + + if not x.is_cuda: + raise ValueError("pointwise_eager expects a CUDA tensor input") + + y = torch.empty_like(x) + n = int(x.numel()) + grid = max(1, (n + block_size - 1) // block_size) + + class _PTStream: + def __cuda_stream__(self): + return (0, torch.cuda.current_stream().cuda_stream) + + device = runtime_cache["device"] + launch = runtime_cache["launch"] + launch_config = runtime_cache["LaunchConfig"] + kernel = runtime_cache["kernel"] + stream = device.create_stream(_PTStream()) + launch( + stream, + launch_config(grid=(grid,), block=(block_size,)), + kernel, + x.data_ptr(), + n, + y.data_ptr(), + ) + return y + + return _eager + + +def manual_cuda_kernel_plugin( + op_name: str, + kernel_source: str, + kernel_name: str, + aot_fn: Callable, + eager_fn: Optional[Callable] = None, + include_paths: Optional[List[str]] = None, + compile_std: str = "c++17", + arch_override: Optional[str] = None, + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, + priority: ConverterPriority = ConverterPriority.STANDARD, + capability_validator: Optional[Callable] = None, + schema: Optional[str] = None, +) -> Callable: + """One-shot decorator for CUDA kernel + custom plugin registration. + + This is a convenience wrapper equivalent to ``cuda_python(...)`` followed by + ``custom_plugin(...)``. 
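+
+    That is, the decorator expands to roughly (a sketch; the remaining
+    keyword arguments are forwarded unchanged)::
+
+        spec = cuda_python(kernel_source, kernel_name,
+                           aot_fn=aot_fn, eager_fn=eager_fn)
+
+        @custom_plugin(op_name, spec,
+                       supports_dynamic_shapes=supports_dynamic_shapes)
+        def meta_fn(...): ...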
+ """ + spec = cuda_python( + kernel_source=kernel_source, + kernel_name=kernel_name, + aot_fn=aot_fn, + eager_fn=eager_fn, + include_paths=include_paths, + compile_std=compile_std, + arch_override=arch_override, + ) + return custom_plugin( + op_name=op_name, + spec=spec, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + priority=priority, + capability_validator=capability_validator, + schema=schema, + ) diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py b/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py new file mode 100644 index 0000000000..6a3e0e4746 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import inspect +import logging +from typing import Callable, List, Optional, get_type_hints + +import torch +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority +from torch_tensorrt.dynamo.conversion.plugins import generate_plugin, generate_plugin_converter + +from torch_tensorrt.annotation._specs import CudaPythonSpec + +_LOGGER = logging.getLogger(__name__) + + +def _patch_trt_shape_expr_reflected_ops() -> None: + # TRT's trtp.ShapeExpr defines __mul__ / __add__ but not the reflected + # __rmul__ / __radd__. torch_tensorrt lowers meta-fn shape expressions via + # sympy.lambdify(..., "math"), which emits canonical forms like + # ``lambda N: 2*N`` — at runtime Python does ``int * ShapeExpr``, falls + # back to ``ShapeExpr.__rmul__``, and crashes with TypeError. Since these + # ops are commutative we can safely alias the reflected version. + try: + import tensorrt.plugin as trtp + except ImportError: + return + cls = getattr(trtp, "ShapeExpr", None) + if cls is None: + return + for fwd, rev in (("__mul__", "__rmul__"), ("__add__", "__radd__")): + if hasattr(cls, fwd) and not hasattr(cls, rev): + try: + setattr(cls, rev, getattr(cls, fwd)) + except (AttributeError, TypeError): + pass + + +_patch_trt_shape_expr_reflected_ops() + +# Keep Library instances alive – torch frees op registrations when a Library is GC'd. +_LIVE_LIBS: List[torch.library.Library] = [] + +_TORCH_TYPE_TO_SCHEMA = { + torch.Tensor: "Tensor", + int: "int", + float: "float", + bool: "bool", + str: "str", +} + + +def _infer_schema(fn: Callable) -> str: + """Derive a TorchScript schema like '(Tensor x, int n) -> Tensor' from type hints.""" + try: + hints = get_type_hints(fn) + except Exception: + hints = {} + + params = list(inspect.signature(fn).parameters.keys()) + args_str = ", ".join( + "{} {}".format( + _TORCH_TYPE_TO_SCHEMA.get(hints.get(p, torch.Tensor), "Tensor"), p + ) + for p in params + ) + + ret = hints.get("return", torch.Tensor) + origin = getattr(ret, "__origin__", None) + if origin is tuple: + ret_str = "({})".format( + ", ".join( + _TORCH_TYPE_TO_SCHEMA.get(t, "Tensor") for t in ret.__args__ + ) + ) + else: + ret_str = _TORCH_TYPE_TO_SCHEMA.get(ret, "Tensor") + + return f"({args_str}) -> {ret_str}" + + +def _torch_op_already_registered(op_name: str) -> bool: + """Return True if ``op_name`` is already known to the torch dispatcher.""" + try: + return bool(torch._C._jit_get_schemas_for_operator(op_name)) + except RuntimeError: + return False + + +def _register_pytorch_op( + op_name: str, + meta_fn: Callable, + eager_fn: Optional[Callable], + schema: Optional[str] = None, +) -> None: + """Register a new PyTorch custom op using torch.library.Library. + + Idempotent: if ``op_name`` is already registered (e.g. 
a prior call in the + same process), this is a no-op rather than re-defining and raising from + the dispatcher. + """ + if _torch_op_already_registered(op_name): + _LOGGER.debug("PyTorch op %s already registered; skipping re-register", op_name) + return + + ns, name = op_name.split("::") + schema_str = schema if schema is not None else _infer_schema(meta_fn) + + lib = torch.library.Library(ns, "FRAGMENT") + lib.define(f"{name}{schema_str}") + _LIVE_LIBS.append(lib) + + if eager_fn is not None: + lib.impl(name, eager_fn, "CUDA") + + torch.library.register_fake(op_name)(meta_fn) + _LOGGER.debug("Registered PyTorch op %s schema: %s%s", op_name, name, schema_str) + + +def _register_aot_impl(op_name: str, ptx: bytes, spec: CudaPythonSpec) -> None: + """Dynamically build a correctly-typed aot_impl and register it with trtp.""" + import tensorrt.plugin as trtp + from typing import Tuple, Union # noqa: F401 – used in annotations dict + + ns, name = op_name.split("::") + torch_op = getattr(getattr(torch.ops, ns), name) + schema = torch_op._schemas[""] + + tensor_arg_names = [ + arg.name + for arg in schema.arguments + if arg.type.isSubtypeOf(torch._C.TensorType.get()) + ] + + ptx_str: str = ptx.decode("utf-8") if isinstance(ptx, bytes) else ptx + kernel_name = spec.kernel_name + user_aot_fn = spec.aot_fn + + # Build the aot_impl function body with the correct positional arg names so + # trtp.aot_impl can match them to the registered descriptor. + sig = ", ".join(tensor_arg_names + ["outputs", "tactic"]) + fn_body = f"""\ +def _aot_impl({sig}): + inputs = [{", ".join(tensor_arg_names)}] + result = _user_aot_fn(inputs, outputs, tactic) + if isinstance(result, tuple) and len(result) == 2: + launch_params, extra_args = result + else: + launch_params, extra_args = result, None + if extra_args is None: + extra_args = _trtp.SymIntExprs(0) + return (_kernel_name, _ptx_str, launch_params, extra_args) +""" + + fn_globals = { + "_user_aot_fn": user_aot_fn, + "_kernel_name": kernel_name, + "_ptx_str": ptx_str, + "_trtp": trtp, + } + local_ns: dict = {} + exec(compile(fn_body, "", "exec"), fn_globals, local_ns) + aot_fn = local_ns["_aot_impl"] + + aot_fn.__annotations__ = { + n: trtp.TensorDesc for n in tensor_arg_names + } + aot_fn.__annotations__["outputs"] = Tuple[trtp.TensorDesc] # type: ignore[name-defined] + aot_fn.__annotations__["tactic"] = int + aot_fn.__annotations__["return"] = Tuple[ # type: ignore[name-defined] + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs # type: ignore[name-defined] + ] + + trtp.aot_impl(op_name)(aot_fn) + _LOGGER.debug("Registered AOT impl for %s", op_name) + + +def register_cuda_python_plugin( + op_name: str, + spec: CudaPythonSpec, + meta_fn: Optional[Callable], + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, + priority: ConverterPriority = ConverterPriority.STANDARD, + capability_validator: Optional[Callable] = None, + register_torch_op: bool = True, + schema: Optional[str] = None, + precompiled_ptx: Optional[bytes] = None, +) -> None: + """Register a NVRTC-compiled CUDA kernel as a TensorRT QDP plugin end-to-end. + + Steps performed: + 1. Compile kernel source to PTX via NVRTC (skipped if ``precompiled_ptx`` is passed). + 2. Optionally register the PyTorch custom op (define + fake impl). + 3. Register the TRT plugin descriptor + JIT impl via generate_plugin(). + 4. Register the AOT impl with the compiled PTX. + 5. Register the Torch-TensorRT converter via generate_plugin_converter(). 
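+
+    A direct (non-decorator) call looks like this (a sketch, assuming ``spec``
+    was built with :func:`cuda_python` and carries ``aot_fn`` / ``eager_fn``)::
+
+        register_cuda_python_plugin(
+            "myns::op",
+            spec,
+            meta_fn=my_meta_fn,
+            supports_dynamic_shapes=True,
+        )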
+ + ``precompiled_ptx`` lets higher-level entry points (e.g. ``auto_cuda_kernel_plugin``) + avoid a redundant second NVRTC pass when they already compiled the source + to build an eager kernel handle. + """ + if spec.aot_fn is None: + raise ValueError( + f"CudaPythonSpec.aot_fn must be set before registering plugin '{op_name}'. " + "Pass aot_fn= to cuda_python() or assign spec.aot_fn directly." + ) + + if precompiled_ptx is not None: + ptx = precompiled_ptx + else: + from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx + + ptx, _device, _kernel = compile_to_ptx( + spec.kernel_source, + spec.kernel_name, + spec.include_paths, + spec.compile_std, + spec.arch_override, + ) + + if register_torch_op: + if meta_fn is None: + raise ValueError( + "meta_fn is required when register_torch_op=True. " + "Provide the fake/meta implementation as the decorated function." + ) + _register_pytorch_op(op_name, meta_fn, spec.eager_fn, schema=schema) + + generate_plugin(op_name) + _register_aot_impl(op_name, ptx, spec) + generate_plugin_converter( + op_name, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + use_aot_if_available=True, + ) + + _LOGGER.info("cuda-python QDP plugin '%s' registered successfully", op_name) diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py b/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py new file mode 100644 index 0000000000..ec4a706404 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import logging +from typing import Any, Optional, Tuple + +_LOGGER = logging.getLogger(__name__) + + +def _cuda_core_imports() -> Tuple[Any, Any, Any, Any, Any]: + """Import cuda.core symbols, accepting both the stable and legacy namespaces.""" + try: + from cuda.core import Device, LaunchConfig, Program, ProgramOptions, launch + + return Device, Program, ProgramOptions, launch, LaunchConfig + except ImportError: + pass + try: + from cuda.core.experimental import ( + Device, + LaunchConfig, + Program, + ProgramOptions, + launch, + ) + + return Device, Program, ProgramOptions, launch, LaunchConfig + except ImportError: + raise ImportError( + "cuda-python is required for cuda_python plugins. " + "Install it with: pip install cuda-python" + ) + + +def compile_to_ptx( + kernel_source: str, + kernel_name: str, + include_paths: list, + compile_std: str = "c++17", + arch_override: Optional[str] = None, +) -> Tuple[bytes, Any, Any]: + """Compile CUDA C++ source to PTX using NVRTC via cuda-python. 
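+
+    A typical call (a sketch; ``CU_SRC`` is assumed to define an
+    ``extern "C" __global__`` function named ``my_kernel``)::
+
+        ptx, device, kernel = compile_to_ptx(
+            CU_SRC, "my_kernel", include_paths=["/usr/local/cuda/include"]
+        )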
+ + Returns: + (ptx_bytes, device, kernel_object) + """ + Device, Program, ProgramOptions, _launch, _LaunchConfig = _cuda_core_imports() + + device = Device() + device.set_current() + arch = arch_override if arch_override else f"sm_{device.arch}" + + options = ProgramOptions( + std=compile_std, + arch=arch, + include_path=include_paths, + ) + program = Program(kernel_source, code_type="c++", options=options) + module = program.compile("ptx", name_expressions=(kernel_name,)) + ptx: bytes = module.code + kernel = module.get_kernel(kernel_name) + _LOGGER.debug( + "Compiled kernel '%s' to PTX for %s (%d bytes)", kernel_name, arch, len(ptx) + ) + return ptx, device, kernel diff --git a/py/torch_tensorrt/annotation/_kernel_plugin.py b/py/torch_tensorrt/annotation/_kernel_plugin.py new file mode 100644 index 0000000000..32ad484db1 --- /dev/null +++ b/py/torch_tensorrt/annotation/_kernel_plugin.py @@ -0,0 +1,676 @@ +"""Derivation engine behind :func:`auto_cuda_kernel_plugin`. + +Given a :class:`KernelSpec`, this builds: + +* a PyTorch meta/fake impl from each :class:`OutputDecl.shape`, +* a PyTorch eager impl that launches the compiled kernel via cuda-python, +* a TensorRT AOT impl with symbolic launch params and extras, +* a PyTorch op schema string. + +and hands them off to :func:`register_cuda_python_plugin`. +""" +from __future__ import annotations + +import logging +import textwrap +from typing import Callable, Dict, List, Optional, Tuple + +import torch + +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority +from torch_tensorrt.annotation._kernel_spec import ( + Custom, + DimSize, + Elementwise, + ExtraArg, + InputDecl, + KernelSpec, + Numel, + OutputDecl, + ReduceDims, + Reduction, + SameAs, + ScalarInput, +) +from torch_tensorrt.annotation._specs import ( + CudaPythonSpec, + _default_cuda_include_paths, +) + +_LOGGER = logging.getLogger(__name__) + + +# ============================================================================= +# ShapeRel evaluation (torch + TRT variants) +# ============================================================================= + + +def _resolve_axis(axis: int, ndim: int) -> int: + return axis if axis >= 0 else ndim + axis + + +def _tensor_input_decls(inputs) -> List[InputDecl]: + """Return only the tensor inputs, preserving order.""" + return [d for d in inputs if isinstance(d, InputDecl)] + + +def _torch_output_shape_dtype( + decl: OutputDecl, + input_tensors: List[torch.Tensor], + input_decls: List[InputDecl], +) -> Tuple[Tuple[int, ...], torch.dtype]: + """Return (shape, dtype) for one output given concrete input tensors. + + ``input_decls`` and ``input_tensors`` must be aligned lists of tensor-only + inputs; scalar inputs are filtered out by the caller. 
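+
+    For example (a sketch): with ``decl.shape = ReduceDims(input_idx=0,
+    dims=(-1,), keepdim=False)`` and an input of shape ``(4, 8)``, the
+    returned shape is ``(4,)`` with the source tensor's dtype; with
+    ``keepdim=True`` the shape is ``(4, 1)``.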
+ """ + rel = decl.shape + name_to_idx = {d.name: i for i, d in enumerate(input_decls)} + + if isinstance(rel, SameAs): + src = input_tensors[rel.input_idx] + dtype = ( + input_tensors[name_to_idx[decl.dtype_from]].dtype + if decl.dtype_from is not None + else src.dtype + ) + return tuple(src.shape), dtype + + if isinstance(rel, ReduceDims): + src = input_tensors[rel.input_idx] + dims = {_resolve_axis(d, src.ndim) for d in rel.dims} + new_shape = [] + for i, s in enumerate(src.shape): + if i in dims: + if rel.keepdim: + new_shape.append(1) + else: + new_shape.append(int(s)) + dtype = ( + input_tensors[name_to_idx[decl.dtype_from]].dtype + if decl.dtype_from is not None + else src.dtype + ) + return tuple(new_shape), dtype + + raise TypeError(f"Unsupported ShapeRel: {rel!r}") + + +# ============================================================================= +# Launch geometry: concrete (eager) + symbolic (aot) +# ============================================================================= + + +def _cdiv_int(a: int, b: int) -> int: + return (a + b - 1) // b + + +def _compute_eager_launch( + geom, + outputs: List[torch.Tensor], + inputs: List[torch.Tensor], +) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]: + """Return (grid, block) both padded to 3-tuples for cuda.core.LaunchConfig.""" + if isinstance(geom, Elementwise): + out = outputs[0] + if geom.layout == "flat": + n = int(out.numel()) + bx = int(geom.block[0]) + grid = (_cdiv_int(n, bx), 1, 1) + block = (bx, 1, 1) + return grid, block + + # layout == "nd" + block_dims = tuple(int(b) for b in geom.block) + out_shape = tuple(int(s) for s in out.shape) + b_ndim = len(block_dims) + if len(out_shape) < b_ndim: + raise ValueError( + f"Elementwise(layout='nd', block={block_dims}) needs output ndim >= " + f"{b_ndim}, got output shape {out_shape}" + ) + # block[0] -> innermost axis; block[i] -> axis out_shape[-(i+1)] + inner_axes = [out_shape[-(i + 1)] for i in range(b_ndim)] + grid_xyz = [_cdiv_int(axis, block_dims[i]) for i, axis in enumerate(inner_axes)] + while len(grid_xyz) < 3: + grid_xyz.append(1) + outer_axes = out_shape[: len(out_shape) - b_ndim] + if outer_axes: + prod = 1 + for s in outer_axes: + prod *= s + grid_xyz[2] *= prod + block_padded = tuple(block_dims) + (1,) * (3 - b_ndim) + return tuple(grid_xyz), block_padded # type: ignore[return-value] + + if isinstance(geom, Reduction): + # One block per "row" = numel(input[0]) / product(reduce_dim_sizes). + # This is independent of whether the output is smaller (sum/max) or + # same-shape (softmax/layernorm): the row count is a property of the + # *input*'s non-reduced axes. + src = inputs[0] + reduce_axes = {_resolve_axis(d, src.ndim) for d in geom.reduce_dims} + rows = 1 + for i, s in enumerate(src.shape): + if i not in reduce_axes: + rows *= int(s) + grid = (rows, 1, 1) + block = (int(geom.block_size), 1, 1) + return grid, block + + if isinstance(geom, Custom): + raise RuntimeError( + "Custom geometry has no eager equivalent. Provide eager_fn directly " + "via manual_cuda_kernel_plugin() / custom_plugin() instead of auto_cuda_kernel_plugin()." 
+ ) + + raise TypeError(f"Unsupported Geometry: {geom!r}") + + +def _compute_aot_launch( + geom, + inputs_desc, + outputs_desc, +): + """Return a ``trtp.KernelLaunchParams`` matching the eager launch.""" + import tensorrt.plugin as trtp + + params = trtp.KernelLaunchParams() + params.shared_mem = 0 + + if isinstance(geom, Elementwise): + out = outputs_desc[0] + if geom.layout == "flat": + n = out.shape_expr.numel() + bx = int(geom.block[0]) + params.grid_x = trtp.cdiv(n, bx) + params.block_x = bx + return params + + block_dims = tuple(int(b) for b in geom.block) + out_shape = out.shape_expr + out_ndim = len(out_shape) + b_ndim = len(block_dims) + if out_ndim < b_ndim: + raise ValueError( + f"Elementwise(layout='nd', block={block_dims}) needs output ndim >= " + f"{b_ndim}, got ndim={out_ndim}" + ) + + # Map block[0] -> last axis, etc. Fill grid_x/y/z as available. + grid_axes = [out_shape[out_ndim - 1 - i] for i in range(b_ndim)] + if b_ndim >= 1: + params.grid_x = trtp.cdiv(grid_axes[0], block_dims[0]) + params.block_x = block_dims[0] + if b_ndim >= 2: + params.grid_y = trtp.cdiv(grid_axes[1], block_dims[1]) + params.block_y = block_dims[1] + if b_ndim >= 3: + params.grid_z = trtp.cdiv(grid_axes[2], block_dims[2]) + params.block_z = block_dims[2] + + outer_ndim = out_ndim - b_ndim + if outer_ndim > 0: + prod = out_shape[0] + for i in range(1, outer_ndim): + prod = prod * out_shape[i] + if b_ndim < 3: + params.grid_z = prod + else: + params.grid_z = params.grid_z * prod + return params + + if isinstance(geom, Reduction): + src = inputs_desc[0] + src_shape = src.shape_expr + src_ndim = len(src_shape) + reduce_axes = {_resolve_axis(d, src_ndim) for d in geom.reduce_dims} + rows = None + for i in range(src_ndim): + if i in reduce_axes: + continue + rows = src_shape[i] if rows is None else rows * src_shape[i] + params.grid_x = rows if rows is not None else 1 + params.block_x = int(geom.block_size) + return params + + if isinstance(geom, Custom): + raise RuntimeError("Custom geometry routed through dedicated path") + + raise TypeError(f"Unsupported Geometry: {geom!r}") + + +# ============================================================================= +# Extras packing +# ============================================================================= + + +def _pack_extras_eager( + extras: List[ExtraArg], + input_by_name: Dict[str, torch.Tensor], +) -> List[int]: + out = [] + for e in extras: + src = input_by_name[e.input_name] + if isinstance(e, Numel): + out.append(int(src.numel())) + elif isinstance(e, DimSize): + axis = _resolve_axis(e.axis, src.ndim) + out.append(int(src.shape[axis])) + else: + raise TypeError(f"Unsupported ExtraArg: {e!r}") + return out + + +def _pack_extras_aot(extras: List[ExtraArg], input_desc_by_name): + import tensorrt.plugin as trtp + + if not extras: + return trtp.SymIntExprs(0) + + sym_exprs = trtp.SymIntExprs(len(extras)) + for i, e in enumerate(extras): + src = input_desc_by_name[e.input_name] + if isinstance(e, Numel): + sym_exprs[i] = trtp.SymInt32(src.shape_expr.numel()) + elif isinstance(e, DimSize): + axis = _resolve_axis(e.axis, len(src.shape_expr)) + sym_exprs[i] = trtp.SymInt32(src.shape_expr[axis]) + else: + raise TypeError(f"Unsupported ExtraArg: {e!r}") + return sym_exprs + + +# ============================================================================= +# Schema + dynamic function builders +# ============================================================================= + + +_PYTYPE_TO_SCHEMA = {float: "float", int: "int", bool: "bool"} + + +def 
_schema_type_for_input(decl) -> str: + if isinstance(decl, ScalarInput): + if decl.py_type not in _PYTYPE_TO_SCHEMA: + raise ValueError( + f"ScalarInput {decl.name!r}: py_type must be float, int, or bool, " + f"got {decl.py_type!r}" + ) + return _PYTYPE_TO_SCHEMA[decl.py_type] + return "Tensor" + + +def _annot_type_for_input(decl): + if isinstance(decl, ScalarInput): + return decl.py_type + return torch.Tensor + + +def _build_schema(spec: KernelSpec) -> str: + args = ", ".join(f"{_schema_type_for_input(d)} {d.name}" for d in spec.inputs) + if len(spec.outputs) == 1: + ret = "Tensor" + else: + ret = "(" + ", ".join("Tensor" for _ in spec.outputs) + ")" + return f"({args}) -> {ret}" + + +def _make_positional_fn(fn: Callable, input_decls) -> Callable: + """Wrap ``fn`` so it carries a proper positional signature. + + ``torch.library.Library.impl`` and ``register_fake`` both introspect the + wrapped callable; a generic ``*args`` function wouldn't satisfy them. + The synthesized wrapper types each parameter per its :class:`InputDecl` / + :class:`ScalarInput` kind so the torch dispatcher matches the schema. + """ + param_names = [d.name for d in input_decls] + sig_pieces = [] + annotations = {} + for d in input_decls: + if isinstance(d, ScalarInput): + sig_pieces.append(f"{d.name}: '{d.py_type.__name__}'") + annotations[d.name] = d.py_type + else: + sig_pieces.append(f"{d.name}: 'torch.Tensor'") + annotations[d.name] = torch.Tensor + sig_src = ", ".join(sig_pieces) + body = textwrap.dedent( + f""" + def _wrapper({sig_src}) -> 'torch.Tensor': + return _fn({", ".join(param_names)}) + """ + ) + ns: dict = {"_fn": fn, "torch": torch} + exec(compile(body, "", "exec"), ns) + wrapper = ns["_wrapper"] + wrapper.__annotations__ = dict(annotations) + wrapper.__annotations__["return"] = torch.Tensor + return wrapper + + +# ============================================================================= +# Kernel compilation (wraps cuda-python) +# ============================================================================= + + +def _compile_kernel(spec: KernelSpec): + """Compile spec.kernel_source to PTX + get a loadable kernel object.""" + try: + from cuda.core import ( + Device, + Program, + ProgramOptions, + ) + except ImportError: + from cuda.core.experimental import ( + Device, + Program, + ProgramOptions, + ) + + device = Device() + device.set_current() + arch = spec.arch_override if spec.arch_override else f"sm_{device.arch}" + include_paths = ( + list(spec.include_paths) + if spec.include_paths is not None + else _default_cuda_include_paths() + ) + options = ProgramOptions(std=spec.compile_std, arch=arch, include_path=include_paths) + program = Program(spec.kernel_source, code_type="c++", options=options) + module = program.compile("ptx", name_expressions=(spec.kernel_name,)) + ptx: bytes = module.code + kernel = module.get_kernel(spec.kernel_name) + return ptx, device, kernel + + +# ============================================================================= +# Eager / meta / aot function factories +# ============================================================================= + + +def _split_inputs(all_args, input_specs): + """Partition positional args into (tensor_inputs, scalar_values) aligned + with the tensor InputDecls and ScalarInputs respectively. 
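+
+    For example (a sketch): with a tensor decl ``x``, a ScalarInput ``alpha``,
+    and a tensor decl ``y``, args ``(t0, 0.5, t1)`` split into
+    ``([t0, t1], [x_decl, y_decl], {"alpha": 0.5})``.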
+ """ + tensors: List[torch.Tensor] = [] + tensor_decls: List[InputDecl] = [] + scalars: Dict[str, object] = {} + for decl, val in zip(input_specs, all_args): + if isinstance(decl, ScalarInput): + scalars[decl.name] = val + else: + tensors.append(val) + tensor_decls.append(decl) + return tensors, tensor_decls, scalars + + +def _make_meta_fn(spec: KernelSpec) -> Callable: + input_specs = list(spec.inputs) + output_decls = list(spec.outputs) + + def _meta(*args): + tensors, tensor_decls, _scalars = _split_inputs(args, input_specs) + device = tensors[0].device if tensors else torch.device("cuda") + outs = [] + for odecl in output_decls: + shape, dtype = _torch_output_shape_dtype(odecl, tensors, tensor_decls) + outs.append(torch.empty(shape, dtype=dtype, device=device)) + return outs[0] if len(outs) == 1 else tuple(outs) + + return _make_positional_fn(_meta, input_specs) + + +def _make_eager_fn(spec: KernelSpec, kernel_obj, device) -> Callable: + try: + from cuda.core import LaunchConfig + from cuda.core import launch as cuda_launch + except ImportError: + from cuda.core.experimental import LaunchConfig + from cuda.core.experimental import launch as cuda_launch + + input_specs = list(spec.inputs) + output_decls = list(spec.outputs) + extras = list(spec.extras) + + class _PTStream: + def __cuda_stream__(self): # noqa: D401 + return (0, torch.cuda.current_stream().cuda_stream) + + def _eager(*args): + tensors, tensor_decls, _scalars = _split_inputs(args, input_specs) + + outs: List[torch.Tensor] = [] + for odecl in output_decls: + shape, dtype = _torch_output_shape_dtype(odecl, tensors, tensor_decls) + outs.append(torch.empty(shape, dtype=dtype, device=tensors[0].device)) + + grid, block = _compute_eager_launch(spec.geometry, outs, tensors) + + input_by_name = {d.name: t for d, t in zip(tensor_decls, tensors)} + extra_vals = _pack_extras_eager(extras, input_by_name) + + # Kernel arg order: inputs in declaration order (ptr for tensors, + # value for scalars), then extras, then output pointers. + arg_list: list = [] + for decl, val in zip(input_specs, args): + if isinstance(decl, ScalarInput): + arg_list.append(_coerce_scalar(val, decl.py_type)) + else: + arg_list.append(val.data_ptr()) + arg_list.extend(extra_vals) + arg_list.extend(t.data_ptr() for t in outs) + + stream = device.create_stream(_PTStream()) + cuda_launch(stream, LaunchConfig(grid=grid, block=block), kernel_obj, *arg_list) + return outs[0] if len(outs) == 1 else tuple(outs) + + return _make_positional_fn(_eager, input_specs) + + +def _coerce_scalar(value, py_type): + """Convert a Python scalar to the ctypes type that cuda.core needs to + forward it by value to the kernel. + """ + import ctypes + + if py_type is float: + return ctypes.c_float(float(value)) + if py_type is int: + return ctypes.c_int32(int(value)) + if py_type is bool: + return ctypes.c_bool(bool(value)) + raise TypeError(f"Unsupported ScalarInput.py_type: {py_type!r}") + + +def _make_aot_fn(spec: KernelSpec) -> Callable: + tensor_input_decls = _tensor_input_decls(spec.inputs) + extras = list(spec.extras) + + def _aot(inputs, outputs, tactic): + # TRT plugin inputs are the tensor-typed args only. The trtp layer + # slots them in by (tensor) arg order, so our tensor_input_decls list + # aligns with inputs positionally. 
+ if isinstance(spec.geometry, Custom): + return spec.geometry.fn(inputs, outputs, tactic) + params = _compute_aot_launch(spec.geometry, inputs, outputs) + input_desc_by_name = {d.name: td for d, td in zip(tensor_input_decls, inputs)} + extra_exprs = _pack_extras_aot(extras, input_desc_by_name) + return params, extra_exprs + + return _aot + + +# ============================================================================= +# Spec validation (decorator-time, fail-fast with actionable messages) +# ============================================================================= + + +def _validate_spec(spec: KernelSpec) -> None: + """Validate a :class:`KernelSpec` before any compilation or registration. + + Catches the common authoring mistakes that would otherwise surface as + late KeyErrors or silent miscomputation at launch time. + """ + if not spec.inputs: + raise ValueError("KernelSpec.inputs must contain at least one InputDecl") + if not spec.outputs: + raise ValueError("KernelSpec.outputs must contain at least one OutputDecl") + + tensor_decls = _tensor_input_decls(spec.inputs) + if not tensor_decls: + raise ValueError( + "KernelSpec.inputs must contain at least one tensor InputDecl; " + "ScalarInput alone is insufficient (shape inference needs a tensor)." + ) + + input_names = {d.name for d in spec.inputs} + if len(input_names) != len(spec.inputs): + raise ValueError("KernelSpec.inputs contains duplicate names") + + # Scalar inputs must have a supported py_type. + for d in spec.inputs: + if isinstance(d, ScalarInput) and d.py_type not in (float, int, bool): + raise ValueError( + f"ScalarInput {d.name!r}: py_type must be float, int, or bool, " + f"got {d.py_type!r}" + ) + + tensor_input_names = {d.name for d in tensor_decls} + output_names = {d.name for d in spec.outputs} + if len(output_names) != len(spec.outputs): + raise ValueError("KernelSpec.outputs contains duplicate names") + + # SameAs/ReduceDims index into the *tensor* input list, so shape decls + # can't reference scalar positions — constrain against that count. + for decl in spec.outputs: + rel = decl.shape + if isinstance(rel, (SameAs, ReduceDims)): + idx = rel.input_idx + if idx < 0 or idx >= len(tensor_decls): + raise ValueError( + f"OutputDecl {decl.name!r}: shape references input_idx={idx} " + f"but spec has {len(tensor_decls)} tensor inputs" + ) + if decl.dtype_from is not None and decl.dtype_from not in tensor_input_names: + raise ValueError( + f"OutputDecl {decl.name!r}: dtype_from={decl.dtype_from!r} is not " + f"a tensor input name. Known tensor inputs: {sorted(tensor_input_names)}" + ) + + # Extras (Numel / DimSize) only make sense against tensor inputs. + for e in spec.extras: + if e.input_name not in tensor_input_names: + raise ValueError( + f"ExtraArg {type(e).__name__}({e.input_name!r}) references unknown " + f"tensor input. 
Known tensor inputs: {sorted(tensor_input_names)}"
+            )
+
+    geom = spec.geometry
+    if isinstance(geom, Elementwise):
+        if not geom.block or any(b <= 0 for b in geom.block):
+            raise ValueError(
+                f"Elementwise.block must be a non-empty tuple of positive ints, "
+                f"got {geom.block!r}"
+            )
+        if geom.layout not in ("flat", "nd"):
+            raise ValueError(
+                f"Elementwise.layout must be 'flat' or 'nd', got {geom.layout!r}"
+            )
+        if geom.layout == "flat" and len(geom.block) != 1:
+            raise ValueError(
+                f"Elementwise(layout='flat') requires block of length 1, got "
+                f"block={geom.block!r}"
+            )
+        if len(geom.block) > 3:
+            raise ValueError(
+                f"Elementwise.block can have at most 3 dims, got {geom.block!r}"
+            )
+    elif isinstance(geom, Reduction):
+        if geom.block_size <= 0:
+            raise ValueError(
+                f"Reduction.block_size must be > 0, got {geom.block_size}"
+            )
+        if not geom.reduce_dims:
+            raise ValueError("Reduction.reduce_dims must be non-empty")
+    elif isinstance(geom, Custom):
+        if not callable(geom.fn):
+            raise ValueError("Custom.fn must be callable")
+    else:
+        raise TypeError(f"Unsupported Geometry: {geom!r}")
+
+
+# =============================================================================
+# Public entry point
+# =============================================================================
+
+
+def auto_cuda_kernel_plugin(
+    op_name: str,
+    spec: KernelSpec,
+    *,
+    supports_dynamic_shapes: bool = True,
+    requires_output_allocator: bool = False,
+    priority: ConverterPriority = ConverterPriority.STANDARD,
+    capability_validator: Optional[Callable] = None,
+) -> None:
+    """Register a CUDA kernel described by a :class:`KernelSpec` end-to-end.
+
+    Unlike :func:`custom_plugin` / :func:`manual_cuda_kernel_plugin`, this API auto-derives
+    the meta fn, eager fn, aot fn, and PyTorch schema from the declarative
+    ``spec`` — the caller does **not** provide any of them.
+
+    The kernel must follow the calling convention
+    ``(input_ptrs..., extras..., output_ptrs...)``.
+    """
+    if not ENABLED_FEATURES.qdp_plugin:
+        raise RuntimeError(
+            "TensorRT QDP plugins are not available. "
+            "Requires TensorRT >= 10.7.0 (and not 10.14.x)."
+        )
+
+    # Late import to avoid circular imports and keep the decorator cheap.
+    from torch_tensorrt.annotation._custom_plugin._descriptor import (
+        register_cuda_python_plugin,
+    )
+
+    _validate_spec(spec)
+
+    ptx, device, kernel_obj = _compile_kernel(spec)
+
+    meta_fn = _make_meta_fn(spec)
+    eager_fn = _make_eager_fn(spec, kernel_obj, device)
+    aot_fn = _make_aot_fn(spec)
+    schema = _build_schema(spec)
+
+    # Reuse register_cuda_python_plugin's existing flow: it registers the
+    # PyTorch op, the TRT AOT impl with the PTX we just compiled (no second
+    # NVRTC pass; see precompiled_ptx below), and the converter.
+ cuda_spec = CudaPythonSpec( + kernel_source=spec.kernel_source, + kernel_name=spec.kernel_name, + aot_fn=aot_fn, + eager_fn=eager_fn, + include_paths=( + list(spec.include_paths) + if spec.include_paths is not None + else _default_cuda_include_paths() + ), + compile_std=spec.compile_std, + arch_override=spec.arch_override, + ) + + register_cuda_python_plugin( + op_name=op_name, + spec=cuda_spec, + meta_fn=meta_fn, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + priority=priority, + capability_validator=capability_validator, + register_torch_op=True, + schema=schema, + precompiled_ptx=ptx, + ) + _LOGGER.info("auto_cuda_kernel_plugin '%s' registered (schema: %s)", op_name, schema) diff --git a/py/torch_tensorrt/annotation/_kernel_spec.py b/py/torch_tensorrt/annotation/_kernel_spec.py new file mode 100644 index 0000000000..94c6bc19fb --- /dev/null +++ b/py/torch_tensorrt/annotation/_kernel_spec.py @@ -0,0 +1,165 @@ +"""Declarative kernel descriptor used by :func:`torch_tensorrt.annotation.auto_cuda_kernel_plugin`. + +This module intentionally contains no runtime logic beyond dataclass +construction. Derivation of meta / eager / aot / schema happens in +``_kernel_plugin.py``. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, List, Literal, Optional, Sequence, Tuple, Union + +import torch + + +# ---------- ShapeRel: how to derive output shape from inputs ---------- + + +@dataclass(frozen=True) +class SameAs: + """Output has the same shape as ``inputs[input_idx]``.""" + + input_idx: int = 0 + + +@dataclass(frozen=True) +class ReduceDims: + """Output = ``inputs[input_idx]`` with ``dims`` removed. + + If ``keepdim=True`` those axes are kept with size 1 instead of removed. + Negative axes are allowed. + """ + + input_idx: int + dims: Tuple[int, ...] + keepdim: bool = False + + +ShapeRel = Union[SameAs, ReduceDims] + + +# ---------- Extra scalar args (between input ptrs and output ptrs) ---------- + + +@dataclass(frozen=True) +class Numel: + """Pass ``inputs[input_name].numel()`` as an ``int`` extra.""" + + input_name: str + + +@dataclass(frozen=True) +class DimSize: + """Pass ``inputs[input_name].shape[axis]`` as an ``int`` extra. + + Negative ``axis`` allowed. + """ + + input_name: str + axis: int + + +ExtraArg = Union[Numel, DimSize] + + +# ---------- Launch geometry ---------- + + +@dataclass(frozen=True) +class Elementwise: + """One thread per output element. + + ``layout="flat"``: 1D launch over the flattened output ``numel``. + ``block = (bx,)`` → ``grid = (cdiv(numel(out), bx),)``. + ``layout="nd"``: the trailing ``len(block)`` axes of the output are + block-parallelized; any leading axes are folded into ``grid_z``. + ``block[0]`` maps to the last (innermost) axis, matching CUDA's + convention that ``grid_x`` / ``block_x`` varies fastest. + """ + + block: Tuple[int, ...] = (256,) + layout: Literal["flat", "nd"] = "flat" + + +@dataclass(frozen=True) +class Reduction: + """One block per output element; block threads cooperate across the + reduction axes. ``reduce_dims`` are axes of the **input** (not output) + that are collapsed. Grid = ``numel(output)``, block = ``block_size``. + """ + + reduce_dims: Tuple[int, ...] + block_size: int = 256 + + +@dataclass(frozen=True) +class Custom: + """Escape hatch. ``fn(inputs, outputs, tactic)`` returns the same shape + as today's hand-written aot_fn: ``(KernelLaunchParams, SymExprs)``. 
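+
+    A minimal sketch of such a fn, assuming a 1-D pointwise kernel (it mirrors
+    the ``pointwise_aot`` helper; names here are illustrative)::
+
+        import tensorrt.plugin as trtp
+
+        def my_launch(inputs, outputs, tactic):
+            n = inputs[0].shape_expr.numel()
+            p = trtp.KernelLaunchParams()
+            p.grid_x, p.block_x, p.shared_mem = trtp.cdiv(n, 256), 256, 0
+            extras = trtp.SymIntExprs(1)
+            extras[0] = trtp.SymInt32(n)
+            return p, extras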
+ """ + + fn: Callable + + +Geometry = Union[Elementwise, Reduction, Custom] + + +# ---------- Input / output decls ---------- + + +@dataclass(frozen=True) +class InputDecl: + """Tensor kernel input. + + The corresponding kernel argument is a ``T*`` (data pointer) at the input + pointer position in the calling convention. + """ + + name: str + dtype: Optional[torch.dtype] = None + + +@dataclass(frozen=True) +class ScalarInput: + """Scalar (non-tensor) kernel input — e.g. ``float alpha`` or ``int k``. + + Scalars are forwarded by value to the kernel at the input position + (after all preceding tensor/scalar inputs, before extras and output + pointers). ``py_type`` must be ``float``, ``int``, or ``bool``. + """ + + name: str + py_type: type # float, int, bool + + +InputSpec = Union[InputDecl, ScalarInput] + + +@dataclass(frozen=True) +class OutputDecl: + name: str + shape: ShapeRel + dtype_from: Optional[str] = None + + +# ---------- Top-level spec ---------- + + +@dataclass +class KernelSpec: + """Fully declarative description of a CUDA kernel. + + Kernel signature convention (required): the ``__global__`` function + receives arguments in this fixed order — input pointers, then extras in + ``extras`` order, then output pointers. + """ + + kernel_source: str + kernel_name: str + inputs: Sequence[InputSpec] + outputs: Sequence[OutputDecl] + extras: Sequence[ExtraArg] + geometry: Geometry + include_paths: Optional[List[str]] = None + compile_std: str = "c++17" + arch_override: Optional[str] = None diff --git a/py/torch_tensorrt/annotation/_specs.py b/py/torch_tensorrt/annotation/_specs.py new file mode 100644 index 0000000000..7d00ba1f26 --- /dev/null +++ b/py/torch_tensorrt/annotation/_specs.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Callable, List, Optional + + +def _default_cuda_include_paths() -> List[str]: + """Resolve CUDA include dir from CUDA_HOME / CUDA_PATH, else default.""" + for env_var in ("CUDA_HOME", "CUDA_PATH"): + root = os.environ.get(env_var) + if root: + return [os.path.join(root, "include")] + return ["/usr/local/cuda/include"] + + +@dataclass +class CudaPythonSpec: + """Specification for a CUDA C++ kernel compiled via NVRTC (cuda-python). + + Create instances via :func:`cuda_python` rather than constructing directly. + """ + + kernel_source: str + kernel_name: str + aot_fn: Optional[Callable] + eager_fn: Optional[Callable] = None + include_paths: List[str] = field(default_factory=_default_cuda_include_paths) + compile_std: str = "c++17" + arch_override: Optional[str] = None diff --git a/pyproject.toml b/pyproject.toml index 47d18ed8fe..a2731226d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,10 @@ docs = [ quantization = ["nvidia-modelopt[all]>=0.27.1"] +# Optional runtime deps for the torch_tensorrt.annotation QDP-plugin layer, +# which compiles user-supplied CUDA C++ kernels via NVRTC. 
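+# Install with: pip install "torch-tensorrt[annotation]"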
+annotation = ["cuda-python"] + [project.urls] Homepage = "https://pytorch.org/tensorrt" Documentation = "https://pytorch.org/tensorrt" diff --git a/setup.py b/setup.py index dba7d5ec6e..808eb7b60b 100644 --- a/setup.py +++ b/setup.py @@ -472,6 +472,8 @@ def run(self): dynamo_packages = [ "torch_tensorrt", + "torch_tensorrt.annotation", + "torch_tensorrt.annotation._custom_plugin", "torch_tensorrt.dynamo", "torch_tensorrt.dynamo.backend", "torch_tensorrt.dynamo.conversion", @@ -506,6 +508,8 @@ def run(self): dynamo_package_dir = { "torch_tensorrt": "py/torch_tensorrt", + "torch_tensorrt.annotation": "py/torch_tensorrt/annotation", + "torch_tensorrt.annotation._custom_plugin": "py/torch_tensorrt/annotation/_custom_plugin", "torch_tensorrt.dynamo": "py/torch_tensorrt/dynamo", "torch_tensorrt.dynamo.backend": "py/torch_tensorrt/dynamo/backend", "torch_tensorrt.dynamo.conversion": "py/torch_tensorrt/dynamo/conversion", diff --git a/tests/py/annotation/__init__.py b/tests/py/annotation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/annotation/conftest.py b/tests/py/annotation/conftest.py new file mode 100644 index 0000000000..11e589f754 --- /dev/null +++ b/tests/py/annotation/conftest.py @@ -0,0 +1,144 @@ +"""Shared CUDA kernel sources, skip marks, and helpers for annotation tests.""" +from __future__ import annotations + +import pytest +import torch +import torch_tensorrt + + +skip_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA device required" +) +skip_no_qdp = pytest.mark.skipif( + not torch_tensorrt.ENABLED_FEATURES.qdp_plugin, + reason="TensorRT QDP plugin not available", +) + + +SIGMOID_SRC = """ +extern "C" __global__ void tta_test_sigmoid( + const float* x, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = 1.0f / (1.0f + __expf(-x[i])); +} +""" + +RELU_FLAT_SRC = """ +extern "C" __global__ void tta_kp_relu_flat( + const float* x, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f; +} +""" + +RELU_ND_SRC = """ +extern "C" __global__ void tta_kp_relu_nd( + const float* x, int H, int W, float* y) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= H || j >= W) return; + float v = x[i * W + j]; + y[i * W + j] = v > 0.f ? 
v : 0.f; +} +""" + +ROW_SUM_SRC = """ +extern "C" __global__ void tta_kp_row_sum( + const float* x, int D, float* y) { + int row = blockIdx.x; + const float* xr = x + row * D; + float s = 0.f; + for (int j = threadIdx.x; j < D; j += blockDim.x) s += xr[j]; + __shared__ float sbuf[256]; + sbuf[threadIdx.x] = s; + __syncthreads(); + for (int step = blockDim.x >> 1; step > 0; step >>= 1) { + if (threadIdx.x < step) sbuf[threadIdx.x] += sbuf[threadIdx.x + step]; + __syncthreads(); + } + if (threadIdx.x == 0) y[row] = sbuf[0]; +} +""" + +ADD_SRC = """ +extern "C" __global__ void tta_kp_add( + const float* a, const float* b, int n, float* c) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) c[i] = a[i] + b[i]; +} +""" + +SCALE_SRC = """ +extern "C" __global__ void tta_kp_scale( + const float* x, float alpha, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = alpha * x[i]; +} +""" + + +def register_once(register_fn): + """Invoke `register_fn()`; swallow duplicate-registration errors on re-run.""" + try: + register_fn() + except Exception: + pass + + +def pointwise_aot(block_size: int = 256): + """Build a minimal trtp aot_fn for 1-D pointwise kernels.""" + import tensorrt.plugin as trtp + + def _aot(inputs, outputs, tactic): + n = inputs[0].shape_expr.numel() + p = trtp.KernelLaunchParams() + p.grid_x, p.block_x, p.shared_mem = trtp.cdiv(n, block_size), block_size, 0 + extra = trtp.SymIntExprs(1) + extra[0] = trtp.SymInt32(n) + return p, extra + + return _aot + + +def make_eager_sigmoid(): + """Compile SIGMOID_SRC once and return an eager launch fn.""" + try: + from cuda.core import Device, LaunchConfig, Program, ProgramOptions, launch + except ImportError: + from cuda.core.experimental import ( + Device, + LaunchConfig, + Program, + ProgramOptions, + launch, + ) + + dev = Device() + dev.set_current() + opts = ProgramOptions( + std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"] + ) + kernel = ( + Program(SIGMOID_SRC, code_type="c++", options=opts) + .compile("ptx", name_expressions=("tta_test_sigmoid",)) + .get_kernel("tta_test_sigmoid") + ) + + class _Stream: + def __cuda_stream__(self): + return (0, torch.cuda.current_stream().cuda_stream) + + def _eager(x: torch.Tensor) -> torch.Tensor: + y = torch.empty_like(x) + n = int(x.numel()) + launch( + dev.create_stream(_Stream()), + LaunchConfig(grid=(max(1, (n + 255) // 256),), block=(256,)), + kernel, + x.data_ptr(), + n, + y.data_ptr(), + ) + return y + + return _eager diff --git a/tests/py/annotation/test_auto_cuda_kernel_plugin.py b/tests/py/annotation/test_auto_cuda_kernel_plugin.py new file mode 100644 index 0000000000..2b67c8b38f --- /dev/null +++ b/tests/py/annotation/test_auto_cuda_kernel_plugin.py @@ -0,0 +1,372 @@ +"""End-to-end tests for torch_tensorrt.annotation.auto_cuda_kernel_plugin.""" +import pytest +import torch +import torch_tensorrt +import torch_tensorrt.annotation as tta +from torch_tensorrt.annotation._kernel_plugin import ( + _torch_output_shape_dtype, + _validate_spec, +) + +from .conftest import ( + ADD_SRC, + RELU_FLAT_SRC, + RELU_ND_SRC, + ROW_SUM_SRC, + SCALE_SRC, + register_once, + skip_no_cuda, + skip_no_qdp, +) + + +# ---- No-GPU: validation & shape inference ---- + + +def _base_spec_kwargs(**overrides): + kwargs = dict( + kernel_source="// no-op", + kernel_name="k", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[], + geometry=tta.Elementwise(block=(256,), layout="flat"), + ) + 
kwargs.update(overrides) + return kwargs + + +@pytest.mark.parametrize( + "overrides, match", + [ + ({"extras": [tta.Numel("y")]}, "references unknown tensor input"), + ( + {"outputs": [tta.OutputDecl("y", shape=tta.ReduceDims(5, (-1,)))]}, + "input_idx=5", + ), + ( + { + "outputs": [ + tta.OutputDecl("y", shape=tta.SameAs(0), dtype_from="not_an_input") + ] + }, + "dtype_from=", + ), + ({"geometry": tta.Elementwise(block=(16, 16), layout="flat")}, "layout='flat'"), + ({"inputs": [tta.InputDecl("x"), tta.InputDecl("x")]}, "duplicate names"), + ({"geometry": tta.Reduction(reduce_dims=(), block_size=256)}, "reduce_dims"), + ], +) +def test_validate_spec_error_paths(overrides, match): + spec = tta.KernelSpec(**_base_spec_kwargs(**overrides)) + with pytest.raises(ValueError, match=match): + _validate_spec(spec) + + +@pytest.mark.parametrize( + "shape_rel, want", + [ + (tta.SameAs(0), (2, 3, 4)), + (tta.ReduceDims(0, (-1,)), (2, 3)), + (tta.ReduceDims(0, (1,), keepdim=True), (2, 1, 4)), + ], +) +def test_shape_inference(shape_rel, want): + x = torch.empty(2, 3, 4) + shape, _ = _torch_output_shape_dtype( + tta.OutputDecl("y", shape=shape_rel), [x], [tta.InputDecl("x")] + ) + assert shape == want + + +# ---- GPU: one geometry per class, eager + TRT compile ---- + + +def _register_relu_flat(): + tta.auto_cuda_kernel_plugin( + "tta_kp::relu_flat", + tta.KernelSpec( + kernel_source=RELU_FLAT_SRC, + kernel_name="tta_kp_relu_flat", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.Numel("x")], + geometry=tta.Elementwise(block=(256,), layout="flat"), + ), + ) + + +@skip_no_cuda +@skip_no_qdp +class TestFlat: + def test_eager_any_rank(self): + register_once(_register_relu_flat) + x = torch.randn(2, 3, 5, 7, device="cuda") + assert torch.allclose(torch.ops.tta_kp.relu_flat(x), torch.relu(x), atol=1e-5) + + def test_trt_compile(self): + register_once(_register_relu_flat) + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.tta_kp.relu_flat(x) + + x = torch.randn(1, 4, 8, 8, device="cuda") + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + with torch.no_grad(): + assert torch.allclose(trt(x), torch.relu(x), atol=1e-2, rtol=1e-2) + + +def _register_relu_nd(): + tta.auto_cuda_kernel_plugin( + "tta_kp::relu_nd", + tta.KernelSpec( + kernel_source=RELU_ND_SRC, + kernel_name="tta_kp_relu_nd", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.DimSize("x", 0), tta.DimSize("x", 1)], + geometry=tta.Elementwise(block=(16, 16), layout="nd"), + ), + ) + + +@skip_no_cuda +@skip_no_qdp +class TestND: + def test_eager_2d(self): + register_once(_register_relu_nd) + x = torch.randn(33, 47, device="cuda") + assert torch.allclose(torch.ops.tta_kp.relu_nd(x), torch.relu(x), atol=1e-5) + + def test_trt_compile(self): + register_once(_register_relu_nd) + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.tta_kp.relu_nd(x) + + x = torch.randn(33, 47, device="cuda") + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + with torch.no_grad(): + assert torch.allclose(trt(x), torch.relu(x), atol=1e-2, rtol=1e-2) + + def test_nd_block_mismatch_raises(self): + # Register a 1-D kernel with a 2-D ND geometry — must raise at launch. 
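+        # (With layout="nd" the output rank must cover every block dim, so the
+        # eager launcher rejects a 1-D tensor against block=(16, 16).)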
+ src_1d = """ + extern "C" __global__ void tta_kp_relu_bad_nd( + const float* x, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f; + } + """ + tta.auto_cuda_kernel_plugin( + "tta_kp::relu_bad_nd", + tta.KernelSpec( + kernel_source=src_1d, + kernel_name="tta_kp_relu_bad_nd", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.Numel("x")], + geometry=tta.Elementwise(block=(16, 16), layout="nd"), + ), + ) + x = torch.randn(32, device="cuda") + with pytest.raises(ValueError, match="ndim >= 2"): + torch.ops.tta_kp.relu_bad_nd(x) + + +def _register_row_sum(op_name: str = "tta_kp::row_sum", *, keepdim: bool = False): + name = op_name.split("::")[-1] + tta.auto_cuda_kernel_plugin( + op_name, + tta.KernelSpec( + kernel_source=ROW_SUM_SRC.replace("tta_kp_row_sum", f"tta_kp_{name}"), + kernel_name=f"tta_kp_{name}", + inputs=[tta.InputDecl("x")], + outputs=[ + tta.OutputDecl("y", shape=tta.ReduceDims(0, (-1,), keepdim=keepdim)) + ], + extras=[tta.DimSize("x", -1)], + geometry=tta.Reduction(reduce_dims=(-1,), block_size=256), + ), + ) + + +@skip_no_cuda +@skip_no_qdp +class TestReduction: + def test_eager_any_rank(self): + register_once(_register_row_sum) + for shape in [(4, 128), (2, 3, 64)]: + x = torch.randn(*shape, device="cuda") + assert torch.allclose( + torch.ops.tta_kp.row_sum(x), x.sum(-1), atol=1e-3, rtol=1e-3 + ) + + def test_trt_compile(self): + register_once(_register_row_sum) + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.tta_kp.row_sum(x) + + x = torch.randn(2, 128, device="cuda") + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + with torch.no_grad(): + assert torch.allclose(trt(x), x.sum(-1), atol=1e-2, rtol=1e-2) + + def test_keepdim_shape(self): + register_once( + lambda: _register_row_sum("tta_kp::row_sum_keepdim", keepdim=True) + ) + x = torch.randn(4, 64, device="cuda") + out = torch.ops.tta_kp.row_sum_keepdim(x) + assert out.shape == (4, 1) + assert torch.allclose(out, x.sum(-1, keepdim=True), atol=1e-3, rtol=1e-3) + + +# ---- GPU: multi-input, scalar input, custom geometry ---- + + +@skip_no_cuda +@skip_no_qdp +def test_multi_input_add(): + tta.auto_cuda_kernel_plugin( + "tta_kp::add", + tta.KernelSpec( + kernel_source=ADD_SRC, + kernel_name="tta_kp_add", + inputs=[tta.InputDecl("a"), tta.InputDecl("b")], + outputs=[tta.OutputDecl("c", shape=tta.SameAs(0))], + extras=[tta.Numel("a")], + geometry=tta.Elementwise(block=(256,), layout="flat"), + ), + ) + a = torch.randn(256, device="cuda") + b = torch.randn(256, device="cuda") + assert torch.allclose(torch.ops.tta_kp.add(a, b), a + b, atol=1e-5) + + +def _register_scale(): + tta.auto_cuda_kernel_plugin( + "tta_kp::scale", + tta.KernelSpec( + kernel_source=SCALE_SRC, + kernel_name="tta_kp_scale", + inputs=[tta.InputDecl("x"), tta.ScalarInput("alpha", float)], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.Numel("x")], + geometry=tta.Elementwise(block=(256,), layout="flat"), + ), + supports_dynamic_shapes=True, + ) + + +@skip_no_cuda +@skip_no_qdp +class TestScalarInput: + def test_schema_has_float(self): + register_once(_register_scale) + schemas = torch._C._jit_get_schemas_for_operator("tta_kp::scale") + assert any("float alpha" in str(s) for s in schemas) + + def test_eager_run(self): + register_once(_register_scale) + x = torch.randn(256, device="cuda") + assert 
torch.allclose(torch.ops.tta_kp.scale(x, 2.5), 2.5 * x, atol=1e-5) + + @pytest.mark.xfail( + strict=True, + reason=( + "torch_tensorrt+TRT plugin machinery coerces scalar attrs via " + "attr_type_annot(numpy_array), which errors on non-0-D arrays; " + "auto_cuda_kernel_plugin scalar inputs themselves work in eager + schema." + ), + ) + def test_trt_compile(self): + register_once(_register_scale) + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.tta_kp.scale(x, 3.0) + + x = torch.randn(128, device="cuda") + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + with torch.no_grad(): + assert torch.allclose(trt(x), 3.0 * x, atol=1e-2, rtol=1e-2) + + +@skip_no_cuda +@skip_no_qdp +def test_custom_geometry_aot_path_only(): + """Custom geometry routes to user's aot_fn; eager must refuse.""" + import tensorrt.plugin as trtp + + captured = {} + + def _aot(inputs, outputs, tactic): + captured["called"] = True + n = inputs[0].shape_expr.numel() + p = trtp.KernelLaunchParams() + p.grid_x, p.block_x, p.shared_mem = trtp.cdiv(n, 256), 256, 0 + extra = trtp.SymIntExprs(1) + extra[0] = trtp.SymInt32(n) + return p, extra + + src = """ + extern "C" __global__ void tta_kp_cust_relu( + const float* x, int n, float* y) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) y[i] = x[i] > 0.f ? x[i] : 0.f; + } + """ + tta.auto_cuda_kernel_plugin( + "tta_kp::custom_geo_relu", + tta.KernelSpec( + kernel_source=src, + kernel_name="tta_kp_cust_relu", + inputs=[tta.InputDecl("x")], + outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))], + extras=[tta.Numel("x")], + geometry=tta.Custom(fn=_aot), + ), + ) + + x = torch.randn(512, device="cuda") + with pytest.raises(RuntimeError, match="Custom geometry has no eager"): + torch.ops.tta_kp.custom_geo_relu(x) + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.tta_kp.custom_geo_relu(x) + + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=[x], + enabled_precisions={torch.float32}, + min_block_size=1, + ) + with torch.no_grad(): + assert torch.allclose(trt(x), torch.relu(x), atol=1e-2, rtol=1e-2) + assert captured.get("called") is True diff --git a/tests/py/annotation/test_manual_cuda_kernel_plugin.py b/tests/py/annotation/test_manual_cuda_kernel_plugin.py new file mode 100644 index 0000000000..d92912365a --- /dev/null +++ b/tests/py/annotation/test_manual_cuda_kernel_plugin.py @@ -0,0 +1,302 @@ +"""Tests for manual_cuda_kernel_plugin and its lower-level building blocks.""" +import pytest +import torch +import torch_tensorrt +import torch_tensorrt.annotation as tta +from torch_tensorrt.annotation._custom_plugin._descriptor import _infer_schema + +from .conftest import ( + SIGMOID_SRC, + make_eager_sigmoid, + pointwise_aot, + skip_no_cuda, + skip_no_qdp, +) + + +# ---- No-GPU: CudaPythonSpec construction ---- + + +class TestCudaPythonSpec: + def test_basic_construction(self): + spec = tta.cuda_python("// src", "my_k", aot_fn=lambda *a: None) + assert spec.kernel_source == "// src" + assert spec.kernel_name == "my_k" + assert spec.compile_std == "c++17" + assert "/usr/local/cuda/include" in spec.include_paths + + def test_overrides(self): + spec = tta.cuda_python( + "// s", "k", + include_paths=["/opt/cuda/include"], + arch_override="sm_90", + ) + assert spec.include_paths == ["/opt/cuda/include"] + assert spec.arch_override == "sm_90" + + def test_aot_fn_settable_post_construction(self): + spec = tta.cuda_python("// s", "k") + assert 
spec.aot_fn is None + spec.aot_fn = lambda *a: None + assert spec.aot_fn is not None + + +# ---- No-GPU: pointwise factories ---- + + +@pytest.mark.parametrize( + "kwargs, match", + [(dict(block_size=0), "block_size"), (dict(input_index=-1), "input_index")], +) +def test_pointwise_aot_factory_validation(kwargs, match): + with pytest.raises(ValueError, match=match): + tta.pointwise_aot(**kwargs) + assert callable(tta.pointwise_aot()) + + +def test_pointwise_eager_factory_validation(): + with pytest.raises(ValueError, match="block_size"): + tta.pointwise_eager("// src", "k", block_size=0) + assert callable(tta.pointwise_eager("// src", "k")) + + +# ---- No-GPU: schema inference (small defs — needs real __annotations__) ---- + + +def test_schema_single_tensor(): + def meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + s = _infer_schema(meta) + assert "Tensor x" in s and "-> Tensor" in s + + +def test_schema_two_tensors(): + def meta(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + s = _infer_schema(meta) + assert "Tensor a" in s and "Tensor b" in s + + +def test_schema_mixed_scalar(): + def meta(x: torch.Tensor, scale: float) -> torch.Tensor: + return torch.empty_like(x) + + s = _infer_schema(meta) + assert "Tensor x" in s and "float scale" in s + + +# ---- No-GPU: decorator plumbing ---- + + +def test_custom_plugin_forwards_explicit_schema(monkeypatch): + from torch_tensorrt.annotation._custom_plugin import _descriptor + + captured = {} + monkeypatch.setattr( + _descriptor, + "register_cuda_python_plugin", + lambda *a, **k: captured.update(k), + ) + + spec = tta.cuda_python("// s", "k", aot_fn=lambda *a: None) + + @tta.custom_plugin("tta_test::schema_forward", spec, schema="(Tensor x) -> Tensor") + def _meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + assert captured["schema"] == "(Tensor x) -> Tensor" + assert captured["register_torch_op"] is True + + +def test_manual_cuda_kernel_plugin_one_shot(monkeypatch): + from torch_tensorrt.annotation._custom_plugin import _descriptor + + captured = {} + monkeypatch.setattr( + _descriptor, + "register_cuda_python_plugin", + lambda *a, **k: captured.update(k), + ) + + @tta.manual_cuda_kernel_plugin( + op_name="tta_test::one_shot", + kernel_source="// s", + kernel_name="k", + aot_fn=lambda *a: None, + eager_fn=lambda x: x, + schema="(Tensor x) -> Tensor", + supports_dynamic_shapes=True, + ) + def _meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + assert captured["op_name"] == "tta_test::one_shot" + assert captured["schema"] == "(Tensor x) -> Tensor" + assert captured["supports_dynamic_shapes"] is True + + +def test_precompiled_ptx_skips_nvrtc(monkeypatch): + """register_cuda_python_plugin(precompiled_ptx=...) 
must not call compile_to_ptx.""" + from torch_tensorrt.annotation._custom_plugin import _descriptor, _nvrtc + from torch_tensorrt.annotation._specs import CudaPythonSpec + + def _fail(*a, **k): + raise AssertionError( + "compile_to_ptx must NOT run when precompiled_ptx is provided" + ) + + monkeypatch.setattr(_nvrtc, "compile_to_ptx", _fail) + for name in ( + "_register_pytorch_op", + "generate_plugin", + "_register_aot_impl", + "generate_plugin_converter", + ): + monkeypatch.setattr(_descriptor, name, lambda *a, **k: None) + + spec = CudaPythonSpec( + kernel_source="// ignored", + kernel_name="k", + aot_fn=lambda *a: None, + eager_fn=lambda x: x, + ) + + def _meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + _descriptor.register_cuda_python_plugin( + op_name="tta_test::ptx_reused", + spec=spec, + meta_fn=_meta, + precompiled_ptx=b"fake-ptx", + ) + + +# ---- GPU: NVRTC compilation ---- + + +@skip_no_cuda +@skip_no_qdp +class TestNVRTC: + def test_compiles_to_ptx(self): + from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx + + ptx, _, _ = compile_to_ptx( + SIGMOID_SRC, "tta_test_sigmoid", ["/usr/local/cuda/include"] + ) + assert isinstance(ptx, bytes) and b"tta_test_sigmoid" in ptx + + def test_invalid_source_raises(self): + from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx + + with pytest.raises(Exception): + compile_to_ptx( + "this is not valid CUDA !!!###", "bad", ["/usr/local/cuda/include"] + ) + + def test_arch_override_respected(self): + from torch_tensorrt.annotation._custom_plugin._nvrtc import compile_to_ptx + + arch = f"sm_{torch.cuda.get_device_capability()[0]}0" + ptx, _, _ = compile_to_ptx( + SIGMOID_SRC, + "tta_test_sigmoid", + ["/usr/local/cuda/include"], + arch_override=arch, + ) + assert isinstance(ptx, bytes) + + +# ---- GPU: integration — register, eager, TRT compile w/ dynamic shapes ---- + + +def _register_sigmoid(op_name: str): + spec = tta.cuda_python( + SIGMOID_SRC, + "tta_test_sigmoid", + aot_fn=pointwise_aot(), + eager_fn=make_eager_sigmoid(), + ) + + @tta.custom_plugin(op_name, spec, supports_dynamic_shapes=True) + def _meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + +@skip_no_cuda +@skip_no_qdp +class TestIntegration: + def test_register_and_eager(self): + try: + _register_sigmoid("tta_test::sigmoid_eager") + except Exception: + pass + x = torch.randn(1024, device="cuda") + assert torch.allclose( + torch.ops.tta_test.sigmoid_eager(x), torch.sigmoid(x), atol=1e-4, rtol=1e-4 + ) + + def test_trt_compile_dynamic_shapes(self): + try: + _register_sigmoid("tta_test::sigmoid_dyn") + except Exception: + pass + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.tta_test.sigmoid_dyn(x) + + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 128), + opt_shape=(1, 512), + max_shape=(1, 2048), + dtype=torch.float32, + ) + ] + trt = torch_tensorrt.compile( + M().cuda().eval(), + inputs=inputs, + enabled_precisions={torch.float32}, + min_block_size=1, + ) + for size in [128, 512, 2048]: + x = torch.randn(1, size, device="cuda") + with torch.no_grad(): + assert torch.allclose( + trt(x), torch.sigmoid(x), atol=1e-2, rtol=1e-2 + ) + + +@skip_no_cuda +@skip_no_qdp +def test_schema_override_integration(): + """End-to-end: schema= overrides the inferred schema at real registration.""" + src = """ + extern "C" __global__ void schema_ov_noop( + const float* x, int n, float alpha, float* y) {} + """ + spec = tta.cuda_python( + src, "schema_ov_noop", + aot_fn=pointwise_aot(), 
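+        # conftest's pointwise_aot supplies the 1-D launch config; the kernel
+        # body above is an empty no-op, so only schema plumbing is exercised.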
+ eager_fn=lambda x, alpha: alpha * x, # reference impl — doesn't touch the kernel + ) + + @tta.custom_plugin( + "tta_test::schema_ov", + spec, + supports_dynamic_shapes=True, + schema="(Tensor x, float alpha) -> Tensor", + ) + def _meta(x, alpha): # no hints — only schema= makes `float alpha` land + return torch.empty_like(x) + + schemas = [ + str(s) for s in torch._C._jit_get_schemas_for_operator("tta_test::schema_ov") + ] + assert any("float alpha" in s for s in schemas) + + x = torch.randn(32, device="cuda") + assert torch.allclose(torch.ops.tta_test.schema_ov(x, 2.5), 2.5 * x, atol=1e-5) From 4e4f65e41ca5bba4329c9d77b541d3fa5d5e5468 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 22 Apr 2026 18:04:48 +0000 Subject: [PATCH 2/2] linting --- .../auto_cuda_kernel_plugin_annotation.py | 6 +- .../manual_cuda_kernel_plugin_annotation.py | 11 ++- py/torch_tensorrt/annotation/__init__.py | 6 +- .../annotation/_custom_plugin/__init__.py | 40 ++++++----- .../annotation/_custom_plugin/_descriptor.py | 38 +++++----- .../annotation/_custom_plugin/_nvrtc.py | 4 +- .../annotation/_kernel_plugin.py | 71 +++++++++++-------- py/torch_tensorrt/annotation/_kernel_spec.py | 6 +- py/torch_tensorrt/annotation/_specs.py | 6 +- tests/py/annotation/conftest.py | 3 +- .../test_auto_cuda_kernel_plugin.py | 3 +- .../test_manual_cuda_kernel_plugin.py | 16 +++-- 12 files changed, 117 insertions(+), 93 deletions(-) diff --git a/examples/dynamo/auto_cuda_kernel_plugin_annotation.py b/examples/dynamo/auto_cuda_kernel_plugin_annotation.py index a3f4dd99d5..49624e1275 100644 --- a/examples/dynamo/auto_cuda_kernel_plugin_annotation.py +++ b/examples/dynamo/auto_cuda_kernel_plugin_annotation.py @@ -26,6 +26,7 @@ import sys import torch + import torch_tensorrt if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin: @@ -58,7 +59,6 @@ import torch_tensorrt.annotation as tta - # Calling convention expected by auto_cuda_kernel_plugin: # (input_ptrs..., extras..., output_ptrs...) @@ -100,7 +100,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = Model().cuda().eval() eager_out = model(x) - print("Eager result matches torch.sigmoid:", torch.allclose(eager_out, ref, atol=1e-4)) + print( + "Eager result matches torch.sigmoid:", torch.allclose(eager_out, ref, atol=1e-4) + ) print("Compiling with Torch-TensorRT...") trt_model = torch_tensorrt.compile( diff --git a/examples/dynamo/manual_cuda_kernel_plugin_annotation.py b/examples/dynamo/manual_cuda_kernel_plugin_annotation.py index a14e8da1cf..f71b6aaba7 100644 --- a/examples/dynamo/manual_cuda_kernel_plugin_annotation.py +++ b/examples/dynamo/manual_cuda_kernel_plugin_annotation.py @@ -22,6 +22,7 @@ import sys import torch + import torch_tensorrt if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin: @@ -34,7 +35,9 @@ try: import tensorrt.plugin as trtp except ImportError: - print("[manual_cuda_kernel_plugin_annotation] Skipping example: tensorrt.plugin unavailable.") + print( + "[manual_cuda_kernel_plugin_annotation] Skipping example: tensorrt.plugin unavailable." 
+ ) sys.exit(0) try: @@ -59,7 +62,6 @@ import torch_tensorrt.annotation as tta - CU_REPEAT2 = """ extern "C" __global__ void repeat2_kernel( const float* __restrict__ x, const int n, float* __restrict__ y) { @@ -141,7 +143,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = Repeat2Model().cuda().eval() eager_out = model(x) - print("Eager result matches repeat_interleave:", torch.allclose(eager_out, ref, atol=1e-4)) + print( + "Eager result matches repeat_interleave:", + torch.allclose(eager_out, ref, atol=1e-4), + ) print("Compiling with Torch-TensorRT...") with torch_tensorrt.logging.debug(): diff --git a/py/torch_tensorrt/annotation/__init__.py b/py/torch_tensorrt/annotation/__init__.py index 31cd62a554..d6af8aafba 100644 --- a/py/torch_tensorrt/annotation/__init__.py +++ b/py/torch_tensorrt/annotation/__init__.py @@ -70,14 +70,14 @@ def forward(self, x): return torch.ops.myns.relu(x) to :func:`manual_cuda_kernel_plugin` and supply ``aot_fn`` / ``eager_fn`` directly. """ -from torch_tensorrt.annotation._specs import CudaPythonSpec from torch_tensorrt.annotation._custom_plugin import ( - manual_cuda_kernel_plugin, cuda_python, custom_plugin, + manual_cuda_kernel_plugin, pointwise_aot, pointwise_eager, ) +from torch_tensorrt.annotation._kernel_plugin import auto_cuda_kernel_plugin from torch_tensorrt.annotation._kernel_spec import ( Custom, DimSize, @@ -91,7 +91,7 @@ def forward(self, x): return torch.ops.myns.relu(x) SameAs, ScalarInput, ) -from torch_tensorrt.annotation._kernel_plugin import auto_cuda_kernel_plugin +from torch_tensorrt.annotation._specs import CudaPythonSpec __all__ = [ "CudaPythonSpec", diff --git a/py/torch_tensorrt/annotation/_custom_plugin/__init__.py b/py/torch_tensorrt/annotation/_custom_plugin/__init__.py index ad0104152f..6e2329de0a 100644 --- a/py/torch_tensorrt/annotation/_custom_plugin/__init__.py +++ b/py/torch_tensorrt/annotation/_custom_plugin/__init__.py @@ -1,11 +1,11 @@ from __future__ import annotations import logging -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority from torch_tensorrt.annotation._specs import CudaPythonSpec, _default_cuda_include_paths +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority _LOGGER = logging.getLogger(__name__) @@ -13,8 +13,8 @@ def cuda_python( kernel_source: str, kernel_name: str, - aot_fn: Optional[Callable] = None, - eager_fn: Optional[Callable] = None, + aot_fn: Optional[Callable[..., Any]] = None, + eager_fn: Optional[Callable[..., Any]] = None, include_paths: Optional[List[str]] = None, compile_std: str = "c++17", arch_override: Optional[str] = None, @@ -50,7 +50,11 @@ def cuda_python( kernel_name=kernel_name, aot_fn=aot_fn, eager_fn=eager_fn, - include_paths=include_paths if include_paths is not None else _default_cuda_include_paths(), + include_paths=( + include_paths + if include_paths is not None + else _default_cuda_include_paths() + ), compile_std=compile_std, arch_override=arch_override, ) @@ -62,9 +66,9 @@ def custom_plugin( supports_dynamic_shapes: bool = False, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, - capability_validator: Optional[Callable] = None, + capability_validator: Optional[Callable[..., Any]] = None, schema: Optional[str] = None, -) -> Callable: +) -> Callable[..., Any]: """Decorator that registers a CUDA kernel as 
a TensorRT QDP plugin. The decorated function acts as the **meta / fake implementation** (shape and @@ -125,7 +129,7 @@ def _(x: torch.Tensor) -> torch.Tensor: register_cuda_python_plugin, ) - def decorator(meta_fn: Callable) -> Callable: + def decorator(meta_fn: Callable[..., Any]) -> Callable[..., Any]: register_cuda_python_plugin( op_name=op_name, spec=spec, @@ -145,7 +149,7 @@ def decorator(meta_fn: Callable) -> Callable: def pointwise_aot( block_size: int = 256, input_index: int = 0, -) -> Callable: +) -> Callable[..., Any]: """Create a common AOT launch-config function for pointwise kernels. The generated function computes ``N = inputs[input_index].shape_expr.numel()`` @@ -160,7 +164,7 @@ def pointwise_aot( if input_index < 0: raise ValueError(f"input_index must be >= 0, got {input_index}") - def _aot(inputs, outputs, tactic): + def _aot(inputs: Any, outputs: Any, tactic: Any) -> Any: import tensorrt.plugin as trtp if input_index >= len(inputs): @@ -186,7 +190,7 @@ def pointwise_eager( compile_std: str = "c++17", arch_override: Optional[str] = None, block_size: int = 256, -) -> Callable: +) -> Callable[..., Any]: """Create an eager CUDA implementation for unary pointwise kernels. The generated function assumes kernel signature: @@ -207,7 +211,7 @@ def pointwise_eager( resolved_include_paths = ( include_paths if include_paths is not None else _default_cuda_include_paths() ) - runtime_cache = {} + runtime_cache: dict[str, Any] = {} def _ensure_compiled() -> None: if runtime_cache: @@ -231,7 +235,7 @@ def _ensure_compiled() -> None: runtime_cache["launch"] = launch runtime_cache["LaunchConfig"] = LaunchConfig - def _eager(x): + def _eager(x: Any) -> Any: import torch _ensure_compiled() @@ -244,7 +248,7 @@ def _eager(x): grid = max(1, (n + block_size - 1) // block_size) class _PTStream: - def __cuda_stream__(self): + def __cuda_stream__(self) -> tuple[int, int]: return (0, torch.cuda.current_stream().cuda_stream) device = runtime_cache["device"] @@ -269,17 +273,17 @@ def manual_cuda_kernel_plugin( op_name: str, kernel_source: str, kernel_name: str, - aot_fn: Callable, - eager_fn: Optional[Callable] = None, + aot_fn: Callable[..., Any], + eager_fn: Optional[Callable[..., Any]] = None, include_paths: Optional[List[str]] = None, compile_std: str = "c++17", arch_override: Optional[str] = None, supports_dynamic_shapes: bool = False, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, - capability_validator: Optional[Callable] = None, + capability_validator: Optional[Callable[..., Any]] = None, schema: Optional[str] = None, -) -> Callable: +) -> Callable[..., Any]: """One-shot decorator for CUDA kernel + custom plugin registration. 
This is a convenience wrapper equivalent to ``cuda_python(...)`` followed by diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py b/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py index 6a3e0e4746..d66c8ae14f 100644 --- a/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py +++ b/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py @@ -2,13 +2,16 @@ import inspect import logging -from typing import Callable, List, Optional, get_type_hints +from typing import Any, Callable, Dict, List, Optional, get_type_hints import torch -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority -from torch_tensorrt.dynamo.conversion.plugins import generate_plugin, generate_plugin_converter from torch_tensorrt.annotation._specs import CudaPythonSpec +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority +from torch_tensorrt.dynamo.conversion.plugins import ( + generate_plugin, + generate_plugin_converter, +) _LOGGER = logging.getLogger(__name__) @@ -49,7 +52,7 @@ def _patch_trt_shape_expr_reflected_ops() -> None: } -def _infer_schema(fn: Callable) -> str: +def _infer_schema(fn: Callable[..., Any]) -> str: """Derive a TorchScript schema like '(Tensor x, int n) -> Tensor' from type hints.""" try: hints = get_type_hints(fn) @@ -68,9 +71,7 @@ def _infer_schema(fn: Callable) -> str: origin = getattr(ret, "__origin__", None) if origin is tuple: ret_str = "({})".format( - ", ".join( - _TORCH_TYPE_TO_SCHEMA.get(t, "Tensor") for t in ret.__args__ - ) + ", ".join(_TORCH_TYPE_TO_SCHEMA.get(t, "Tensor") for t in ret.__args__) ) else: ret_str = _TORCH_TYPE_TO_SCHEMA.get(ret, "Tensor") @@ -88,8 +89,8 @@ def _torch_op_already_registered(op_name: str) -> bool: def _register_pytorch_op( op_name: str, - meta_fn: Callable, - eager_fn: Optional[Callable], + meta_fn: Callable[..., Any], + eager_fn: Optional[Callable[..., Any]], schema: Optional[str] = None, ) -> None: """Register a new PyTorch custom op using torch.library.Library. 
@@ -118,9 +119,10 @@ def _register_pytorch_op( def _register_aot_impl(op_name: str, ptx: bytes, spec: CudaPythonSpec) -> None: """Dynamically build a correctly-typed aot_impl and register it with trtp.""" - import tensorrt.plugin as trtp from typing import Tuple, Union # noqa: F401 – used in annotations dict + import tensorrt.plugin as trtp + ns, name = op_name.split("::") torch_op = getattr(getattr(torch.ops, ns), name) schema = torch_op._schemas[""] @@ -157,17 +159,15 @@ def _aot_impl({sig}): "_ptx_str": ptx_str, "_trtp": trtp, } - local_ns: dict = {} + local_ns: Dict[str, Any] = {} exec(compile(fn_body, "", "exec"), fn_globals, local_ns) aot_fn = local_ns["_aot_impl"] - aot_fn.__annotations__ = { - n: trtp.TensorDesc for n in tensor_arg_names - } - aot_fn.__annotations__["outputs"] = Tuple[trtp.TensorDesc] # type: ignore[name-defined] + aot_fn.__annotations__ = dict.fromkeys(tensor_arg_names, trtp.TensorDesc) + aot_fn.__annotations__["outputs"] = Tuple[trtp.TensorDesc] aot_fn.__annotations__["tactic"] = int - aot_fn.__annotations__["return"] = Tuple[ # type: ignore[name-defined] - Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs # type: ignore[name-defined] + aot_fn.__annotations__["return"] = Tuple[ + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs ] trtp.aot_impl(op_name)(aot_fn) @@ -177,11 +177,11 @@ def _aot_impl({sig}): def register_cuda_python_plugin( op_name: str, spec: CudaPythonSpec, - meta_fn: Optional[Callable], + meta_fn: Optional[Callable[..., Any]], supports_dynamic_shapes: bool = False, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, - capability_validator: Optional[Callable] = None, + capability_validator: Optional[Callable[..., Any]] = None, register_torch_op: bool = True, schema: Optional[str] = None, precompiled_ptx: Optional[bytes] = None, diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py b/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py index ec4a706404..9fa95cc49a 100644 --- a/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py +++ b/py/torch_tensorrt/annotation/_custom_plugin/_nvrtc.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple _LOGGER = logging.getLogger(__name__) @@ -34,7 +34,7 @@ def _cuda_core_imports() -> Tuple[Any, Any, Any, Any, Any]: def compile_to_ptx( kernel_source: str, kernel_name: str, - include_paths: list, + include_paths: List[str], compile_std: str = "c++17", arch_override: Optional[str] = None, ) -> Tuple[bytes, Any, Any]: diff --git a/py/torch_tensorrt/annotation/_kernel_plugin.py b/py/torch_tensorrt/annotation/_kernel_plugin.py index 32ad484db1..8cc83c9eb1 100644 --- a/py/torch_tensorrt/annotation/_kernel_plugin.py +++ b/py/torch_tensorrt/annotation/_kernel_plugin.py @@ -9,16 +9,16 @@ and hands them off to :func:`register_cuda_python_plugin`. 
""" + from __future__ import annotations import logging import textwrap -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority from torch_tensorrt.annotation._kernel_spec import ( Custom, DimSize, @@ -37,6 +37,7 @@ CudaPythonSpec, _default_cuda_include_paths, ) +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority _LOGGER = logging.getLogger(__name__) @@ -50,7 +51,7 @@ def _resolve_axis(axis: int, ndim: int) -> int: return axis if axis >= 0 else ndim + axis -def _tensor_input_decls(inputs) -> List[InputDecl]: +def _tensor_input_decls(inputs: Sequence[Any]) -> List[InputDecl]: """Return only the tensor inputs, preserving order.""" return [d for d in inputs if isinstance(d, InputDecl)] @@ -107,7 +108,7 @@ def _cdiv_int(a: int, b: int) -> int: def _compute_eager_launch( - geom, + geom: Any, outputs: List[torch.Tensor], inputs: List[torch.Tensor], ) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]: @@ -169,10 +170,10 @@ def _compute_eager_launch( def _compute_aot_launch( - geom, - inputs_desc, - outputs_desc, -): + geom: Any, + inputs_desc: Any, + outputs_desc: Any, +) -> Any: """Return a ``trtp.KernelLaunchParams`` matching the eager launch.""" import tensorrt.plugin as trtp @@ -263,7 +264,7 @@ def _pack_extras_eager( return out -def _pack_extras_aot(extras: List[ExtraArg], input_desc_by_name): +def _pack_extras_aot(extras: List[ExtraArg], input_desc_by_name: Dict[str, Any]) -> Any: import tensorrt.plugin as trtp if not extras: @@ -290,7 +291,7 @@ def _pack_extras_aot(extras: List[ExtraArg], input_desc_by_name): _PYTYPE_TO_SCHEMA = {float: "float", int: "int", bool: "bool"} -def _schema_type_for_input(decl) -> str: +def _schema_type_for_input(decl: Any) -> str: if isinstance(decl, ScalarInput): if decl.py_type not in _PYTYPE_TO_SCHEMA: raise ValueError( @@ -301,7 +302,7 @@ def _schema_type_for_input(decl) -> str: return "Tensor" -def _annot_type_for_input(decl): +def _annot_type_for_input(decl: Any) -> Any: if isinstance(decl, ScalarInput): return decl.py_type return torch.Tensor @@ -316,7 +317,9 @@ def _build_schema(spec: KernelSpec) -> str: return f"({args}) -> {ret}" -def _make_positional_fn(fn: Callable, input_decls) -> Callable: +def _make_positional_fn( + fn: Callable[..., Any], input_decls: Sequence[Any] +) -> Callable[..., Any]: """Wrap ``fn`` so it carries a proper positional signature. 
``torch.library.Library.impl`` and ``register_fake`` both introspect the @@ -341,9 +344,9 @@ def _wrapper({sig_src}) -> 'torch.Tensor': return _fn({", ".join(param_names)}) """ ) - ns: dict = {"_fn": fn, "torch": torch} + ns: Dict[str, Any] = {"_fn": fn, "torch": torch} exec(compile(body, "", "exec"), ns) - wrapper = ns["_wrapper"] + wrapper: Callable[..., Any] = ns["_wrapper"] wrapper.__annotations__ = dict(annotations) wrapper.__annotations__["return"] = torch.Tensor return wrapper @@ -354,7 +357,7 @@ def _wrapper({sig_src}) -> 'torch.Tensor': # ============================================================================= -def _compile_kernel(spec: KernelSpec): +def _compile_kernel(spec: KernelSpec) -> Tuple[bytes, Any, Any]: """Compile spec.kernel_source to PTX + get a loadable kernel object.""" try: from cuda.core import ( @@ -377,7 +380,9 @@ def _compile_kernel(spec: KernelSpec): if spec.include_paths is not None else _default_cuda_include_paths() ) - options = ProgramOptions(std=spec.compile_std, arch=arch, include_path=include_paths) + options = ProgramOptions( + std=spec.compile_std, arch=arch, include_path=include_paths + ) program = Program(spec.kernel_source, code_type="c++", options=options) module = program.compile("ptx", name_expressions=(spec.kernel_name,)) ptx: bytes = module.code @@ -390,7 +395,9 @@ def _compile_kernel(spec: KernelSpec): # ============================================================================= -def _split_inputs(all_args, input_specs): +def _split_inputs( + all_args: Sequence[Any], input_specs: Sequence[Any] +) -> Tuple[List[torch.Tensor], List[InputDecl], Dict[str, object]]: """Partition positional args into (tensor_inputs, scalar_values) aligned with the tensor InputDecls and ScalarInputs respectively. """ @@ -406,11 +413,11 @@ def _split_inputs(all_args, input_specs): return tensors, tensor_decls, scalars -def _make_meta_fn(spec: KernelSpec) -> Callable: +def _make_meta_fn(spec: KernelSpec) -> Callable[..., Any]: input_specs = list(spec.inputs) output_decls = list(spec.outputs) - def _meta(*args): + def _meta(*args: Any) -> Any: tensors, tensor_decls, _scalars = _split_inputs(args, input_specs) device = tensors[0].device if tensors else torch.device("cuda") outs = [] @@ -422,7 +429,9 @@ def _meta(*args): return _make_positional_fn(_meta, input_specs) -def _make_eager_fn(spec: KernelSpec, kernel_obj, device) -> Callable: +def _make_eager_fn( + spec: KernelSpec, kernel_obj: Any, device: Any +) -> Callable[..., Any]: try: from cuda.core import LaunchConfig from cuda.core import launch as cuda_launch @@ -435,10 +444,10 @@ def _make_eager_fn(spec: KernelSpec, kernel_obj, device) -> Callable: extras = list(spec.extras) class _PTStream: - def __cuda_stream__(self): # noqa: D401 + def __cuda_stream__(self) -> Tuple[int, int]: # noqa: D401 return (0, torch.cuda.current_stream().cuda_stream) - def _eager(*args): + def _eager(*args: Any) -> Any: tensors, tensor_decls, _scalars = _split_inputs(args, input_specs) outs: List[torch.Tensor] = [] @@ -453,7 +462,7 @@ def _eager(*args): # Kernel arg order: inputs in declaration order (ptr for tensors, # value for scalars), then extras, then output pointers. 
- arg_list: list = [] + arg_list: List[Any] = [] for decl, val in zip(input_specs, args): if isinstance(decl, ScalarInput): arg_list.append(_coerce_scalar(val, decl.py_type)) @@ -469,7 +478,7 @@ def _eager(*args): return _make_positional_fn(_eager, input_specs) -def _coerce_scalar(value, py_type): +def _coerce_scalar(value: Any, py_type: Any) -> Any: """Convert a Python scalar to the ctypes type that cuda.core needs to forward it by value to the kernel. """ @@ -484,11 +493,11 @@ def _coerce_scalar(value, py_type): raise TypeError(f"Unsupported ScalarInput.py_type: {py_type!r}") -def _make_aot_fn(spec: KernelSpec) -> Callable: +def _make_aot_fn(spec: KernelSpec) -> Callable[..., Any]: tensor_input_decls = _tensor_input_decls(spec.inputs) extras = list(spec.extras) - def _aot(inputs, outputs, tactic): + def _aot(inputs: Any, outputs: Any, tactic: Any) -> Any: # TRT plugin inputs are the tensor-typed args only. The trtp layer # slots them in by (tensor) arg order, so our tensor_input_decls list # aligns with inputs positionally. @@ -589,9 +598,7 @@ def _validate_spec(spec: KernelSpec) -> None: ) elif isinstance(geom, Reduction): if geom.block_size <= 0: - raise ValueError( - f"Reduction.block_size must be > 0, got {geom.block_size}" - ) + raise ValueError(f"Reduction.block_size must be > 0, got {geom.block_size}") if not geom.reduce_dims: raise ValueError("Reduction.reduce_dims must be non-empty") elif isinstance(geom, Custom): @@ -613,7 +620,7 @@ def auto_cuda_kernel_plugin( supports_dynamic_shapes: bool = True, requires_output_allocator: bool = False, priority: ConverterPriority = ConverterPriority.STANDARD, - capability_validator: Optional[Callable] = None, + capability_validator: Optional[Callable[..., Any]] = None, ) -> None: """Register a CUDA kernel described by a :class:`KernelSpec` end-to-end. @@ -673,4 +680,6 @@ def auto_cuda_kernel_plugin( schema=schema, precompiled_ptx=ptx, ) - _LOGGER.info("auto_cuda_kernel_plugin '%s' registered (schema: %s)", op_name, schema) + _LOGGER.info( + "auto_cuda_kernel_plugin '%s' registered (schema: %s)", op_name, schema + ) diff --git a/py/torch_tensorrt/annotation/_kernel_spec.py b/py/torch_tensorrt/annotation/_kernel_spec.py index 94c6bc19fb..ed188aaf1b 100644 --- a/py/torch_tensorrt/annotation/_kernel_spec.py +++ b/py/torch_tensorrt/annotation/_kernel_spec.py @@ -4,14 +4,14 @@ construction. Derivation of meta / eager / aot / schema happens in ``_kernel_plugin.py``. """ + from __future__ import annotations from dataclasses import dataclass -from typing import Callable, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple, Union import torch - # ---------- ShapeRel: how to derive output shape from inputs ---------- @@ -98,7 +98,7 @@ class Custom: as today's hand-written aot_fn: ``(KernelLaunchParams, SymExprs)``. 
""" - fn: Callable + fn: Callable[..., Any] Geometry = Union[Elementwise, Reduction, Custom] diff --git a/py/torch_tensorrt/annotation/_specs.py b/py/torch_tensorrt/annotation/_specs.py index 7d00ba1f26..b198e055c5 100644 --- a/py/torch_tensorrt/annotation/_specs.py +++ b/py/torch_tensorrt/annotation/_specs.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass, field -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional def _default_cuda_include_paths() -> List[str]: @@ -23,8 +23,8 @@ class CudaPythonSpec: kernel_source: str kernel_name: str - aot_fn: Optional[Callable] - eager_fn: Optional[Callable] = None + aot_fn: Optional[Callable[..., Any]] + eager_fn: Optional[Callable[..., Any]] = None include_paths: List[str] = field(default_factory=_default_cuda_include_paths) compile_std: str = "c++17" arch_override: Optional[str] = None diff --git a/tests/py/annotation/conftest.py b/tests/py/annotation/conftest.py index 11e589f754..ab2b9cdad7 100644 --- a/tests/py/annotation/conftest.py +++ b/tests/py/annotation/conftest.py @@ -1,10 +1,11 @@ """Shared CUDA kernel sources, skip marks, and helpers for annotation tests.""" + from __future__ import annotations import pytest import torch -import torch_tensorrt +import torch_tensorrt skip_no_cuda = pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA device required" diff --git a/tests/py/annotation/test_auto_cuda_kernel_plugin.py b/tests/py/annotation/test_auto_cuda_kernel_plugin.py index 2b67c8b38f..3e4ac30aa9 100644 --- a/tests/py/annotation/test_auto_cuda_kernel_plugin.py +++ b/tests/py/annotation/test_auto_cuda_kernel_plugin.py @@ -1,6 +1,8 @@ """End-to-end tests for torch_tensorrt.annotation.auto_cuda_kernel_plugin.""" + import pytest import torch + import torch_tensorrt import torch_tensorrt.annotation as tta from torch_tensorrt.annotation._kernel_plugin import ( @@ -19,7 +21,6 @@ skip_no_qdp, ) - # ---- No-GPU: validation & shape inference ---- diff --git a/tests/py/annotation/test_manual_cuda_kernel_plugin.py b/tests/py/annotation/test_manual_cuda_kernel_plugin.py index d92912365a..4694bfda8b 100644 --- a/tests/py/annotation/test_manual_cuda_kernel_plugin.py +++ b/tests/py/annotation/test_manual_cuda_kernel_plugin.py @@ -1,6 +1,8 @@ """Tests for manual_cuda_kernel_plugin and its lower-level building blocks.""" + import pytest import torch + import torch_tensorrt import torch_tensorrt.annotation as tta from torch_tensorrt.annotation._custom_plugin._descriptor import _infer_schema @@ -13,7 +15,6 @@ skip_no_qdp, ) - # ---- No-GPU: CudaPythonSpec construction ---- @@ -27,7 +28,8 @@ def test_basic_construction(self): def test_overrides(self): spec = tta.cuda_python( - "// s", "k", + "// s", + "k", include_paths=["/opt/cuda/include"], arch_override="sm_90", ) @@ -265,9 +267,7 @@ def forward(self, x): for size in [128, 512, 2048]: x = torch.randn(1, size, device="cuda") with torch.no_grad(): - assert torch.allclose( - trt(x), torch.sigmoid(x), atol=1e-2, rtol=1e-2 - ) + assert torch.allclose(trt(x), torch.sigmoid(x), atol=1e-2, rtol=1e-2) @skip_no_cuda @@ -279,9 +279,11 @@ def test_schema_override_integration(): const float* x, int n, float alpha, float* y) {} """ spec = tta.cuda_python( - src, "schema_ov_noop", + src, + "schema_ov_noop", aot_fn=pointwise_aot(), - eager_fn=lambda x, alpha: alpha * x, # reference impl — doesn't touch the kernel + eager_fn=lambda x, alpha: alpha + * x, # reference impl — doesn't touch the kernel ) @tta.custom_plugin(