From 81c6507f85aca24924646c4ae88375319db2163b Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Thu, 8 Jan 2026 13:22:30 -0800 Subject: [PATCH 1/4] Improve documentation after trying on a new machine. --- docsrc/getting_started/tensorrt_rtx.rst | 158 ++++++++++++++++++++++-- 1 file changed, 150 insertions(+), 8 deletions(-) diff --git a/docsrc/getting_started/tensorrt_rtx.rst b/docsrc/getting_started/tensorrt_rtx.rst index 0c474fc89f..353d82a5ed 100644 --- a/docsrc/getting_started/tensorrt_rtx.rst +++ b/docsrc/getting_started/tensorrt_rtx.rst @@ -17,7 +17,7 @@ For detailed information about TensorRT-RTX, refer to: * `TensorRT-RTX Documentation `_ -Currently, Torch-TensorRT only supports TensorRT-RTX for experimental purposes. +Currently, Torch-TensorRT only supports TensorRT-RTX for experimental purposes. Torch-TensorRT by default uses standard TensorRT during the build and run. To use TensorRT-RTX: @@ -28,6 +28,84 @@ To use TensorRT-RTX: Prerequisites ------------- +Clone the Repository +~~~~~~~~~~~~~~~~~~~~~ + +First, clone the Torch-TensorRT repository: + +.. code-block:: sh + + git clone https://github.com/pytorch/TensorRT.git + cd TensorRT + +Install System Dependencies +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**In Linux:** + +Install Python development headers (required for building Python extensions): + +.. code-block:: sh + + # For Python 3.12 (adjust version number based on your Python version) + sudo apt install python3.12-dev + +Install CUDA Toolkit +~~~~~~~~~~~~~~~~~~~~ + +Download and install the CUDA Toolkit from the `NVIDIA Developer website `_. + +**Important:** Check the required CUDA version in the `MODULE.bazel `_ file. You must install the exact CUDA toolkit version specified there (for example, at the time of writing, CUDA 13.0 is required). + +After installation, set the ``CUDA_HOME`` environment variable: + +.. code-block:: sh + + export CUDA_HOME=/usr/local/cuda + # Add this to your ~/.bashrc or ~/.zshrc to make it persistent + +Install Python Dependencies +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**It is strongly recommended to use a virtual environment** to avoid conflicts with system packages: + +.. code-block:: sh + + # Create a virtual environment + python -m venv .venv + + # Activate the virtual environment + source .venv/bin/activate # On Linux/Mac + # OR on Windows: + # .venv\Scripts\activate + +Before building, install the required Python packages: + +.. code-block:: sh + + # Install setuptools (provides distutils) + pip install setuptools + + # Install PyTorch nightly build (check CUDA version in MODULE.bazel) + # Replace cuXXX with your CUDA version (e.g., cu130 for CUDA 13.0) + pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cuXXX + + # If you encounter version conflicts during the build, you may need to specify + # the exact PyTorch version constraint. Check pyproject.toml for requirements. + # For example, if pyproject.toml specifies torch>=2.10.0.dev,<2.11.0: + # pip install --pre "torch>=2.10.0.dev,<2.11.0" torchvision --index-url https://download.pytorch.org/whl/nightly/cu130 + + # Install additional build dependencies + pip install pyyaml numpy + +.. note:: + + The PyTorch version requirement is defined in `pyproject.toml `_ (build requirements) and `setup.py `_ (runtime requirements). If you encounter version-related errors during installation, refer to these files for the exact version constraints. + +.. 
note:: + + Remember to activate the virtual environment (``source .venv/bin/activate``) whenever you work with this project or run the build commands. + Install Bazel ~~~~~~~~~~~~~ @@ -51,7 +129,7 @@ Bazel is required to build the wheel with TensorRT-RTX. Install TensorRT-RTX Tarball ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -TensorRT-RTX tarball can be downloaded from https://developer.nvidia.com/tensorrt-rtx. +TensorRT-RTX tarball can be downloaded from https://developer.nvidia.com/tensorrt-rtx. Currently, Torch-TensorRT uses TensorRT-RTX version **1.2.0.54**. Once downloaded: @@ -79,7 +157,7 @@ Make sure you add the lib path to the Windows system variable ``PATH``. Install TensorRT-RTX Wheel ~~~~~~~~~~~~~~~~~~~~~~~~~~ -Currently, the `tensorrt_rtx` wheel is not published on PyPI. +Currently, the `tensorrt_rtx` wheel is not published on PyPI. You must install it manually from the downloaded tarball. .. code-block:: sh @@ -93,20 +171,31 @@ Build Torch-TensorRT with TensorRT-RTX Build Locally with TensorRT-RTX ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Before building, ensure you have completed all the prerequisite steps above, including: + +- Cloning the repository +- Installing Python dependencies (setuptools, torch, pyyaml, numpy) +- Setting CUDA_HOME environment variable +- Installing the correct CUDA toolkit version +- Installing Python development headers +- Installing Bazel + +Then build the wheel: + .. code-block:: sh # If you have previously built with standard TensorRT, make sure to clean the build environment, # otherwise it will use the existing .so built with standard TensorRT, which is not compatible with TensorRT-RTX. python setup.py clean bazel clean --expunge - #remove everything under build directory, + # Remove everything under build directory rm -rf build/* # Build wheel with TensorRT-RTX python setup.py bdist_wheel --use-rtx - # Install the wheel - python -m pip install dist/torch-tensorrt-*.whl + # Install the wheel (note: the wheel filename uses underscores, not hyphens) + python -m pip install dist/torch_tensorrt-*.whl Quick Start ----------- @@ -119,7 +208,60 @@ Quick Start Troubleshooting --------------- -If you encounter load or link errors, check if `tensorrt_rtx` is linked correctly. +Common Issues +~~~~~~~~~~~~~ + +**Missing distutils module** + +If you encounter ``ModuleNotFoundError: No module named 'distutils'``, install setuptools: + +.. code-block:: sh + + pip install setuptools + +**Missing CUDA_HOME environment variable** + +If you encounter ``OSError: CUDA_HOME environment variable is not set``, set the CUDA_HOME path: + +.. code-block:: sh + + export CUDA_HOME=/usr/local/cuda + +**CUDA version mismatch** + +If you encounter errors about CUDA paths not existing (e.g., ``/usr/local/cuda-X.Y/ does not exist``), ensure you have the correct CUDA version installed. Check the required version in `MODULE.bazel `_. You may need to: + +1. Update your NVIDIA drivers +2. Download and install the specific CUDA toolkit version required by MODULE.bazel +3. Clean and rebuild after installing the correct version + +**PyTorch version mismatch** + +If you encounter an error like ``ERROR: No matching distribution found for torch=X.Y.Z.dev`` (for example, ``torch<2.11.0,>=2.10.0.dev``), you need to install a compatible PyTorch nightly version. + +First, check the exact version constraint in `pyproject.toml `_, then install with that constraint: + +.. 
code-block:: sh + + # Example: if pyproject.toml requires torch>=2.10.0.dev,<2.11.0 + # and MODULE.bazel specifies CUDA 13.0 (cu130): + pip install --pre "torch>=2.10.0.dev,<2.11.0" torchvision --index-url https://download.pytorch.org/whl/nightly/cu130 + +Replace the version constraint and CUDA version (cuXXX) according to your project's requirements. + +**Missing Python development headers** + +If you encounter ``fatal error: Python.h: No such file or directory``, install the Python development package: + +.. code-block:: sh + + # For Python 3.12 (adjust version based on your Python) + sudo apt install python3.12-dev + +Verifying TensorRT-RTX Linkage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you encounter load or link errors, check if `tensorrt_rtx` is linked correctly. If not, clean up the environment and rebuild. **In Linux:** @@ -127,7 +269,7 @@ If not, clean up the environment and rebuild. .. code-block:: sh # Ensure only tensorrt_rtx is installed (no standard tensorrt wheels) - python -m pip list | grep tensorrt + python -m pip list | grep tensorrt # Check if libtorchtrt.so links to the correct tensorrt_rtx shared object trt_install_path=$(python -m pip show torch-tensorrt | grep "Location" | awk '{print $2}')/torch_tensorrt From 299b095bb9e3d4e8a7bb40d7ca384e47e8e68f8a Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Thu, 8 Jan 2026 15:16:45 -0800 Subject: [PATCH 2/4] Clarified instructions. --- docsrc/getting_started/tensorrt_rtx.rst | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docsrc/getting_started/tensorrt_rtx.rst b/docsrc/getting_started/tensorrt_rtx.rst index 353d82a5ed..380b87cc19 100644 --- a/docsrc/getting_started/tensorrt_rtx.rst +++ b/docsrc/getting_started/tensorrt_rtx.rst @@ -86,14 +86,13 @@ Before building, install the required Python packages: # Install setuptools (provides distutils) pip install setuptools + # Install PyTorch + # Note: If you are building Torch-TensorRT at tip-of-tree, you need to install the latest PyTorch nightly build rather than the stable release. See below for details. + pip install torch torchvision + # Install PyTorch nightly build (check CUDA version in MODULE.bazel) # Replace cuXXX with your CUDA version (e.g., cu130 for CUDA 13.0) - pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cuXXX - - # If you encounter version conflicts during the build, you may need to specify - # the exact PyTorch version constraint. Check pyproject.toml for requirements. - # For example, if pyproject.toml specifies torch>=2.10.0.dev,<2.11.0: - # pip install --pre "torch>=2.10.0.dev,<2.11.0" torchvision --index-url https://download.pytorch.org/whl/nightly/cu130 + # pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cuXXX # Install additional build dependencies pip install pyyaml numpy From 90d6cd22bb2d34fb8ad150981671daf5a0270d43 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Wed, 14 Jan 2026 16:49:40 -0800 Subject: [PATCH 3/4] Implement rms norm converter. 
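For reference, this converter lowers aten._fused_rms_norm into elementwise and reduce TensorRT layers following the usual RMS-norm definition. A minimal PyTorch sketch of the intended semantics (illustrative only, not the converter code; the helper name rms_norm_reference is made up for this note):

    import torch
    from typing import List, Optional

    def rms_norm_reference(x: torch.Tensor,
                           normalized_shape: List[int],
                           weight: Optional[torch.Tensor] = None,
                           eps: float = 1e-5) -> torch.Tensor:
        # Normalize over the trailing len(normalized_shape) dimensions.
        dims = list(range(x.dim() - len(normalized_shape), x.dim()))
        # RMS = sqrt(mean(x^2) + eps); keepdim so the result broadcasts back.
        rms = torch.sqrt(torch.mean(x * x, dim=dims, keepdim=True) + eps)
        out = x / rms
        # Optional elementwise scale, broadcast over the leading dimensions.
        if weight is not None:
            out = out * weight
        return out

    # Example matching the shapes used in the new unit tests.
    x = torch.randn(2, 4, 8)
    w = torch.randn(8)
    print(rms_norm_reference(x, [8], w).shape)  # torch.Size([2, 4, 8])

The TensorRT lowering below mirrors these steps with mul, reduce-mean, add, sqrt, div, and an optional expand + mul for the weight.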
--- .../dynamo/conversion/aten_ops_converters.py | 28 ++++++ .../conversion/impl/normalization/ops.py | 86 +++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 02b6eb6377..67bc8f1d99 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -219,6 +219,34 @@ def aten_ops_native_group_norm( ) +@dynamo_tensorrt_converter( + torch.ops.aten._fused_rms_norm.default, + supports_dynamic_shapes=True, +) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_fused_rms_norm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.fused_rms_norm( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + normalized_shape=args[1], + weight=args_bounds_check(args, 2), + eps=args_bounds_check(args, 3), + ) + + def parse_cat_args( args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Tuple[List[Any], int]: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index f12b16b150..cfd47af475 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -854,3 +854,89 @@ def cdist_forward( return_indices=False, ) return dist + + +def fused_rms_norm( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: trt.ITensor, + normalized_shape: List[int], + weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]], + eps: Optional[float], +) -> Tuple[trt.ITensor, torch.Tensor]: + """ + RMS Normalization: output = input / sqrt(mean(input^2) + eps) * weight + + Args: + ctx: ConversionContext containing the TensorRT network + target: Target of calling node + source_ir: SourceIR of calling converter + name: Name of the calling layer + input: Input tensor to normalize + normalized_shape: Shape over which to normalize (list of ints) + weight: Optional weight/scale parameter + eps: Epsilon for numerical stability (default: 1e-5) + + Returns: + Tuple of (normalized_output, rstd_placeholder) + Note: rstd (reciprocal standard deviation) is returned as None placeholder + """ + if eps is None: + eps = 1e-5 + + # Calculate dimensions to normalize over (similar to layer_norm) + # normalized_shape specifies the last N dimensions + dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape))) + axes = get_axes_for_reduce_op(dims) + + # Square the input + input_squared = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_input_squared", input, input + ) + + # Compute mean of squared values + mean_squared = impl.reduce.mean( + ctx, target, source_ir, f"{name}_mean_squared", input_squared, dim=dims, keepdim=True + ) + + # Add epsilon for numerical stability + eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", input.dtype) + mean_squared_eps = impl.elementwise.add( + ctx, target, source_ir, f"{name}_mean_squared_eps", mean_squared, eps_tensor + ) + + # Compute RMS = sqrt(mean(input^2) + eps) + rms = impl.unary.sqrt(ctx, target, source_ir, f"{name}_rms", mean_squared_eps) + + # Normalize: input / rms + normalized = impl.elementwise.div( + ctx, target, source_ir, f"{name}_normalized", input, rms + ) + + # Apply weight (scale) if 
provided + if weight is not None: + weight_trt = get_trt_tensor(ctx, weight, f"{name}_weight") + + # Cast weight to match input dtype + weight_trt = cast_trt_tensor( + ctx, weight_trt, input.dtype, f"{name}_weight_cast", target, source_ir + ) + + # Expand weight to match input shape if needed + if tuple(input.shape) != tuple(weight_trt.shape): + weight_trt = impl.slice.expand( + ctx, target, source_ir, f"{name}_expand_weight", weight_trt, input.shape + ) + + # Multiply normalized output by weight + output = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_output", normalized, weight_trt + ) + else: + output = normalized + + # Return (output, rstd_placeholder) + # PyTorch returns (output, rstd) but we return None for rstd as it's typically not used + return output, None From 80d63a24f9ee2d541b6e7e0abc7f73a937d46ed3 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Wed, 14 Jan 2026 17:11:45 -0800 Subject: [PATCH 4/4] Add test for rms norm converter. --- .../dynamo/conversion/test_rms_norm_aten.py | 286 ++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_rms_norm_aten.py diff --git a/tests/py/dynamo/conversion/test_rms_norm_aten.py b/tests/py/dynamo/conversion/test_rms_norm_aten.py new file mode 100644 index 0000000000..868994829b --- /dev/null +++ b/tests/py/dynamo/conversion/test_rms_norm_aten.py @@ -0,0 +1,286 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestFusedRMSNormConverter(DispatchTestCase): + """ + Tests for the aten._fused_rms_norm.default converter. + + RMS Normalization formula: output = input / sqrt(mean(input^2) + eps) * weight + + The operation signature is: _fused_rms_norm(input, normalized_shape, weight, eps) + Returns: (output, rstd) - where rstd is the reciprocal standard deviation + """ + + @parameterized.expand( + [ + # Test normalizing over last dimension + ("1d_last_dim", (2, 4, 8), [8]), + # Test normalizing over last 2 dimensions + ("2d_last_two_dims", (2, 4, 8), [4, 8]), + # Test normalizing over all dimensions + ("3d_all_dims", (2, 4, 8), [2, 4, 8]), + # Test with 4D tensor, last dimension + ("4d_last_dim", (2, 3, 4, 8), [8]), + # Test with 4D tensor, last 2 dimensions + ("4d_last_two_dims", (2, 3, 4, 8), [4, 8]), + # Test with 4D tensor, last 3 dimensions + ("4d_last_three_dims", (2, 3, 4, 8), [3, 4, 8]), + ] + ) + def test_rms_norm_with_weight(self, name, input_shape, normalized_shape): + """ + Test RMS norm with weight parameter across various tensor shapes. + This tests: + - Correct dimension calculation for normalization + - Weight broadcasting/expansion to match input shape + - Output correctness vs PyTorch reference + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, weight, 1e-5 + )[0] # Return only the normalized output, not rstd + + inputs = [ + torch.randn(input_shape), + torch.randn(normalized_shape), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand( + [ + # Test without weight (None) + ("1d_no_weight", (2, 4, 8), [8]), + ("2d_no_weight", (2, 4, 8), [4, 8]), + ("4d_no_weight", (2, 3, 4, 8), [8]), + ] + ) + def test_rms_norm_without_weight(self, name, input_shape, normalized_shape): + """ + Test RMS norm without weight parameter (weight=None). 
+ This ensures the converter handles optional weight correctly. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, None, 1e-5 + )[0] + + inputs = [torch.randn(input_shape)] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand( + [ + # Test different epsilon values + ("eps_1e5", (2, 4, 8), [8], 1e-5), + ("eps_1e6", (2, 4, 8), [8], 1e-6), + ("eps_1e4", (2, 4, 8), [8], 1e-4), + ] + ) + def test_rms_norm_different_eps(self, name, input_shape, normalized_shape, eps): + """ + Test RMS norm with different epsilon values. + Epsilon is critical for numerical stability, especially with small values. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, weight, eps + )[0] + + inputs = [ + torch.randn(input_shape), + torch.randn(normalized_shape), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_with_dynamic_shape_batch(self): + """ + Test RMS norm with dynamic batch dimension. + This is common in inference scenarios where batch size varies. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [128], weight, 1e-6 + )[0] + + input_specs = [ + Input( + shape=(-1, 128), + dtype=torch.float32, + shape_ranges=[((1, 128), (4, 128), (8, 128))], + ), + Input( + shape=(128,), + dtype=torch.float32, + ), + ] + + self.run_test_with_dynamic_shape( + RMSNorm(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_with_dynamic_shape_sequence(self): + """ + Test RMS norm with dynamic sequence length. + This is critical for transformer models with variable sequence lengths. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [256], weight, 1e-5 + )[0] + + input_specs = [ + Input( + shape=(2, -1, 256), + dtype=torch.float32, + shape_ranges=[((2, 16, 256), (2, 64, 256), (2, 128, 256))], + ), + Input( + shape=(256,), + dtype=torch.float32, + ), + ] + + self.run_test_with_dynamic_shape( + RMSNorm(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_with_dynamic_shape_multi_dim(self): + """ + Test RMS norm with multiple dynamic dimensions. + Tests both batch and sequence length being dynamic simultaneously. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [64], weight, 1e-6 + )[0] + + input_specs = [ + Input( + shape=(-1, -1, 64), + dtype=torch.float32, + shape_ranges=[((1, 8, 64), (4, 16, 64), (8, 32, 64))], + ), + Input( + shape=(64,), + dtype=torch.float32, + ), + ] + + self.run_test_with_dynamic_shape( + RMSNorm(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_2d_input(self): + """ + Test RMS norm with 2D input (batch, features). + Common in MLP layers or simple feedforward networks. 
+ """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [512], weight, 1e-5 + )[0] + + inputs = [ + torch.randn(32, 512), + torch.randn(512), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_large_hidden_dim(self): + """ + Test RMS norm with larger hidden dimensions typical in modern LLMs. + Tests numerical stability and performance with realistic model sizes. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [4096], weight, 1e-6 + )[0] + + inputs = [ + torch.randn(2, 8, 4096), + torch.randn(4096), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_flux_pattern(self): + """ + Test RMS norm with pattern similar to FLUX and modern diffusion models. + This tests the actual use case that motivated the converter implementation. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + # FLUX-style: normalize over last dimension with small epsilon + normalized_shape = [x.shape[-1]] + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, weight, 1e-6 + )[0] + + inputs = [ + torch.randn(1, 16, 3072), # Typical FLUX dimensions + torch.randn(3072), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests()