diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 02b6eb6377..67bc8f1d99 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -219,6 +219,34 @@ def aten_ops_native_group_norm(
     )
 
 
+@dynamo_tensorrt_converter(
+    torch.ops.aten._fused_rms_norm.default,
+    supports_dynamic_shapes=True,
+)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_fused_rms_norm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.fused_rms_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        normalized_shape=args[1],
+        weight=args_bounds_check(args, 2),
+        eps=args_bounds_check(args, 3),
+    )
+
+
 def parse_cat_args(
     args: Tuple[Argument, ...], kwargs: Dict[str, Any]
 ) -> Tuple[List[Any], int]:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index f12b16b150..cfd47af475 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -854,3 +854,89 @@ def cdist_forward(
         return_indices=False,
     )
     return dist
+
+
+def fused_rms_norm(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: trt.ITensor,
+    normalized_shape: List[int],
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    eps: Optional[float],
+) -> Tuple[trt.ITensor, Optional[torch.Tensor]]:
+    """
+    RMS Normalization: output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        ctx: ConversionContext containing the TensorRT network
+        target: Target of the calling node
+        source_ir: SourceIR of the calling converter
+        name: Name of the calling layer
+        input: Input tensor to normalize
+        normalized_shape: Shape over which to normalize (list of ints)
+        weight: Optional weight/scale parameter
+        eps: Epsilon for numerical stability (default: 1e-5)
+
+    Returns:
+        Tuple of (normalized_output, rstd_placeholder)
+        Note: rstd (the reciprocal standard deviation) is returned as a None placeholder
+    """
+    if eps is None:
+        eps = 1e-5
+
+    # Calculate the dimensions to normalize over (similar to layer_norm);
+    # normalized_shape specifies the last N dimensions
+    dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)
+
+    # Square the input
+    input_squared = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_input_squared", input, input
+    )
+
+    # Compute the mean of the squared values
+    mean_squared = impl.reduce.mean(
+        ctx, target, source_ir, f"{name}_mean_squared", input_squared, dim=dims, keepdim=True
+    )
+
+    # Add epsilon for numerical stability
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", input.dtype)
+    mean_squared_eps = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_mean_squared_eps", mean_squared, eps_tensor
+    )
+
+    # Compute RMS = sqrt(mean(input^2) + eps)
+    rms = impl.unary.sqrt(ctx, target, source_ir, f"{name}_rms", mean_squared_eps)
+
+    # Normalize: input / rms
+    normalized = impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_normalized", input, rms
+    )
+
+    # Apply the weight (scale) if provided
+    if weight is not None:
+        weight_trt = get_trt_tensor(ctx, weight, f"{name}_weight")
+
+        # Cast the weight to match the input dtype
+        weight_trt = cast_trt_tensor(
+            ctx, weight_trt, input.dtype, f"{name}_weight_cast", target, source_ir
+        )
+
+        # Expand the weight to match the input shape if needed
+        if tuple(input.shape) != tuple(weight_trt.shape):
+            weight_trt = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_weight", weight_trt, input.shape
+            )
+
+        # Multiply the normalized output by the weight
+        output = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_output", normalized, weight_trt
+        )
+    else:
+        output = normalized
+
+    # Return (output, rstd_placeholder)
+    # PyTorch returns (output, rstd), but we return None for rstd as it is typically not used
+    return output, None
diff --git a/tests/py/dynamo/conversion/test_rms_norm_aten.py b/tests/py/dynamo/conversion/test_rms_norm_aten.py
new file mode 100644
index 0000000000..868994829b
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_rms_norm_aten.py
@@ -0,0 +1,286 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestFusedRMSNormConverter(DispatchTestCase):
+    """
+    Tests for the aten._fused_rms_norm.default converter.
+
+    RMS Normalization formula: output = input / sqrt(mean(input^2) + eps) * weight
+
+    The operation signature is: _fused_rms_norm(input, normalized_shape, weight, eps)
+    Returns: (output, rstd), where rstd is the reciprocal standard deviation
+    """
+
+    @parameterized.expand(
+        [
+            # Test normalizing over the last dimension
+            ("1d_last_dim", (2, 4, 8), [8]),
+            # Test normalizing over the last 2 dimensions
+            ("2d_last_two_dims", (2, 4, 8), [4, 8]),
+            # Test normalizing over all dimensions
+            ("3d_all_dims", (2, 4, 8), [2, 4, 8]),
+            # Test with a 4D tensor, last dimension
+            ("4d_last_dim", (2, 3, 4, 8), [8]),
+            # Test with a 4D tensor, last 2 dimensions
+            ("4d_last_two_dims", (2, 3, 4, 8), [4, 8]),
+            # Test with a 4D tensor, last 3 dimensions
+            ("4d_last_three_dims", (2, 3, 4, 8), [3, 4, 8]),
+        ]
+    )
+    def test_rms_norm_with_weight(self, name, input_shape, normalized_shape):
+        """
+        Test RMS norm with the weight parameter across various tensor shapes.
+        This tests:
+        - Correct dimension calculation for normalization
+        - Weight broadcasting/expansion to match the input shape
+        - Output correctness against the PyTorch reference
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, weight, 1e-5
+                )[0]  # Return only the normalized output, not rstd
+
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(normalized_shape),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            # Test without a weight (None)
+            ("1d_no_weight", (2, 4, 8), [8]),
+            ("2d_no_weight", (2, 4, 8), [4, 8]),
+            ("4d_no_weight", (2, 3, 4, 8), [8]),
+        ]
+    )
+    def test_rms_norm_without_weight(self, name, input_shape, normalized_shape):
+        """
+        Test RMS norm without the weight parameter (weight=None).
+        This ensures the converter handles the optional weight correctly.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, None, 1e-5
+                )[0]
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            # Test different epsilon values
+            ("eps_1e5", (2, 4, 8), [8], 1e-5),
+            ("eps_1e6", (2, 4, 8), [8], 1e-6),
+            ("eps_1e4", (2, 4, 8), [8], 1e-4),
+        ]
+    )
+    def test_rms_norm_different_eps(self, name, input_shape, normalized_shape, eps):
+        """
+        Test RMS norm with different epsilon values.
+        Epsilon is critical for numerical stability, especially with small values.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, weight, eps
+                )[0]
+
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(normalized_shape),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_with_dynamic_shape_batch(self):
+        """
+        Test RMS norm with a dynamic batch dimension.
+        This is common in inference scenarios where the batch size varies.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, [128], weight, 1e-6
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 128),
+                dtype=torch.float32,
+                shape_ranges=[((1, 128), (4, 128), (8, 128))],
+            ),
+            Input(
+                shape=(128,),
+                dtype=torch.float32,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            RMSNorm(),
+            input_specs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_with_dynamic_shape_sequence(self):
+        """
+        Test RMS norm with a dynamic sequence length.
+        This is critical for transformer models with variable sequence lengths.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, [256], weight, 1e-5
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(2, -1, 256),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 256), (2, 64, 256), (2, 128, 256))],
+            ),
+            Input(
+                shape=(256,),
+                dtype=torch.float32,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            RMSNorm(),
+            input_specs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_with_dynamic_shape_multi_dim(self):
+        """
+        Test RMS norm with multiple dynamic dimensions.
+        Tests the batch and sequence dimensions being dynamic simultaneously.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, [64], weight, 1e-6
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, -1, 64),
+                dtype=torch.float32,
+                shape_ranges=[((1, 8, 64), (4, 16, 64), (8, 32, 64))],
+            ),
+            Input(
+                shape=(64,),
+                dtype=torch.float32,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            RMSNorm(),
+            input_specs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_2d_input(self):
+        """
+        Test RMS norm with a 2D input (batch, features).
+        Common in MLP layers or simple feedforward networks.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, [512], weight, 1e-5
+                )[0]
+
+        inputs = [
+            torch.randn(32, 512),
+            torch.randn(512),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_large_hidden_dim(self):
+        """
+        Test RMS norm with larger hidden dimensions typical of modern LLMs.
+        Tests numerical stability and performance with realistic model sizes.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, [4096], weight, 1e-6
+                )[0]
+
+        inputs = [
+            torch.randn(2, 8, 4096),
+            torch.randn(4096),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_flux_pattern(self):
+        """
+        Test RMS norm with a pattern similar to FLUX and other modern diffusion models.
+        This tests the actual use case that motivated the converter implementation.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                # FLUX-style: normalize over the last dimension with a small epsilon
+                normalized_shape = [x.shape[-1]]
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, weight, 1e-6
+                )[0]
+
+        inputs = [
+            torch.randn(1, 16, 3072),  # Typical FLUX dimensions
+            torch.randn(3072),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
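Reviewer note (not part of the diff): below is a minimal, standalone sketch of the math the new converter implements, checked against a plain-PyTorch reference. It assumes torch.nn.functional.rms_norm is available (recent PyTorch releases); the helper name rms_norm_reference is hypothetical and used only for illustration.

    import torch
    import torch.nn.functional as F

    def rms_norm_reference(x, normalized_shape, weight=None, eps=1e-6):
        # Normalize over the last len(normalized_shape) dimensions, mirroring the
        # dim computation in impl.normalization.fused_rms_norm above.
        dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
        rms = torch.sqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
        out = x / rms
        return out * weight if weight is not None else out

    x, w = torch.randn(2, 4, 8), torch.randn(8)
    torch.testing.assert_close(
        rms_norm_reference(x, [8], w, eps=1e-6),
        F.rms_norm(x, [8], weight=w, eps=1e-6),
    )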