From 55010b648c796ba6e8a7230ce767400dc319b0a0 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Fri, 16 Jan 2026 21:26:25 +0000
Subject: [PATCH 1/2] Added fuse_rms_norm lowering

---
 .../lowering/passes/_aten_lowering_pass.py    |  2 ++
 .../lowering/passes/replace_fused_rms_norm.py | 27 +++++++++++++++++++
 2 files changed, 29 insertions(+)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 8ad5f2fcae..402929d6f2 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -18,6 +18,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
 from .repair_input_as_output import repair_input_as_output
+from .replace_fused_rms_norm import replace_fused_rms_norm
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .rule_based_autocast import rule_based_autocast
 
@@ -25,6 +26,7 @@
     remove_detach,
     remove_assert_nodes,
     rule_based_autocast,
+    replace_fused_rms_norm,
 ]
 
 post_lowering_pass_list = [
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
new file mode 100644
index 0000000000..78eec1ce02
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
@@ -0,0 +1,27 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+
+logger = logging.getLogger(__name__)
+
+
+def replace_fused_rms_norm(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Replace fused rms norm ops in the graph"""
+    count = 0
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten._fused_rms_norm.default:
+            # Replace fused rms norm with standard rms norm
+            new_node = gm.graph.call_function(
+                torch.ops.aten.rms_norm.default,
+                args=node.args,
+            )
+            gm.graph.replace_node_with_new_node(node, new_node)
+            gm.graph.erase_node(node)
+            count += 1
+
+    logger.debug(f"Replaced {count} fused rms norm nodes:\n{gm.graph}")
+
+    return gm

From 8c9d9e232b8ddf589a38a09b710cdeb4a4bdc38b Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 20 Jan 2026 23:18:51 +0000
Subject: [PATCH 2/2] changed lowering pass to post lowering and implemented rms_norm

---
 .../lowering/passes/_aten_lowering_pass.py    |   2 +-
 .../lowering/passes/replace_fused_rms_norm.py |  72 ++++-
 .../lowering/test_fused_rms_norm_aten.py      | 278 ++++++++++++++++++
 3 files changed, 344 insertions(+), 8 deletions(-)
 create mode 100644 tests/py/dynamo/lowering/test_fused_rms_norm_aten.py

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 402929d6f2..08c1cb9144 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -26,10 +26,10 @@
     remove_detach,
     remove_assert_nodes,
     rule_based_autocast,
-    replace_fused_rms_norm,
 ]
 
 post_lowering_pass_list = [
+    replace_fused_rms_norm,
     remove_input_alias_fixing_clones,
     constant_fold,
     repair_input_as_output,
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
index 78eec1ce02..162c00328f 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
@@ -1,7 +1,12 @@
+import copy
 import logging
+import operator
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -13,15 +18,68 @@ def replace_fused_rms_norm(
     count = 0
     for node in gm.graph.nodes:
         if node.target == torch.ops.aten._fused_rms_norm.default:
-            # Replace fused rms norm with standard rms norm
-            new_node = gm.graph.call_function(
-                torch.ops.aten.rms_norm.default,
-                args=node.args,
-            )
-            gm.graph.replace_node_with_new_node(node, new_node)
-            gm.graph.erase_node(node)
+            new_node = process_fused_rms_norm_node(node, gm)
             count += 1
 
     logger.debug(f"Replaced {count} fused rms norm nodes:\n{gm.graph}")
 
+    gm = clean_up_graph_after_modifications(gm)
+
     return gm
+
+
+def process_fused_rms_norm_node(
+    node: torch.fx.Node, gm: torch.fx.GraphModule
+) -> torch.fx.Node:
+
+    x, shape, weight, eps = node.args[0], node.args[1], node.args[2], node.args[3]
+    if eps is None:
+        eps = 1e-5
+    # Calculate dimensions to normalize over (similar to layer_norm)
+    # normalized_shape specifies the last N dimensions
+    x_dim = len(node.meta["val"][0].shape)
+    dims_to_reduce = []
+    for i in range(len(shape)):
+        dims_to_reduce.append(x_dim - i - 1)
+
+    with gm.graph.inserting_before(node):
+        # Replace fused rms norm with standard rms norm
+        x_squared = gm.graph.call_function(
+            torch.ops.aten.mul.Tensor,
+            args=(x, x),
+        )
+        x_squared_sum = gm.graph.call_function(
+            torch.ops.aten.mean.dim,
+            args=(x_squared, dims_to_reduce, True),
+        )
+        x_squared_sum_eps = gm.graph.call_function(
+            torch.ops.aten.add.Tensor,
+            args=(x_squared_sum, eps),
+        )
+        x_squared_sum_eps_sqrt = gm.graph.call_function(
+            torch.ops.aten.sqrt.default,
+            args=(x_squared_sum_eps,),
+        )
+        x_normalized = gm.graph.call_function(
+            torch.ops.aten.div.Tensor,
+            args=(x, x_squared_sum_eps_sqrt),
+        )
+        if weight is not None:
+            x_normalized = gm.graph.call_function(
+                torch.ops.aten.mul.Tensor,
+                args=(x_normalized, weight),
+            )
+
+    x_normalized.meta = {}
+
+    for user in list(node.users):
+        if user.op == "call_function" and user.target == operator.getitem:
+            # If the getitem is extracting the first element (the output tensor)
+            if not x_normalized.meta:
+                x_normalized.meta = copy.copy(node.meta)
+            user.replace_all_uses_with(x_normalized)
+            gm.graph.erase_node(user)
+
+    gm.graph.erase_node(node)
+
+    return x_normalized
diff --git a/tests/py/dynamo/lowering/test_fused_rms_norm_aten.py b/tests/py/dynamo/lowering/test_fused_rms_norm_aten.py
new file mode 100644
index 0000000000..4a4eb6420e
--- /dev/null
+++ b/tests/py/dynamo/lowering/test_fused_rms_norm_aten.py
@@ -0,0 +1,278 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from ..conversion.harness import DispatchTestCase
+
+
+class TestFusedRMSNormConverter(DispatchTestCase):
+    """
+    Tests for the aten._fused_rms_norm.default converter.
+    RMS Normalization formula: output = input / sqrt(mean(input^2) + eps) * weight
+    The operation signature is: _fused_rms_norm(input, normalized_shape, weight, eps)
+    Returns: (output, rstd) - where rstd is the reciprocal standard deviation
+    """
+
+    @parameterized.expand(
+        [
+            # Test normalizing over last dimension
+            ("1d_last_dim", (2, 4, 8), [8]),
+            # Test normalizing over last 2 dimensions
+            ("2d_last_two_dims", (2, 4, 8), [4, 8]),
+            # Test normalizing over all dimensions
+            ("3d_all_dims", (2, 4, 8), [2, 4, 8]),
+            # Test with 4D tensor, last dimension
+            ("4d_last_dim", (2, 3, 4, 8), [8]),
+            # Test with 4D tensor, last 2 dimensions
+            ("4d_last_two_dims", (2, 3, 4, 8), [4, 8]),
+            # Test with 4D tensor, last 3 dimensions
+            ("4d_last_three_dims", (2, 3, 4, 8), [3, 4, 8]),
+        ]
+    )
+    def test_rms_norm_with_weight(self, name, input_shape, normalized_shape):
+        """
+        Test RMS norm with weight parameter across various tensor shapes.
+        This tests:
+        - Correct dimension calculation for normalization
+        - Weight broadcasting/expansion to match input shape
+        - Output correctness vs PyTorch reference
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, weight, 1e-5
+                )[
+                    0
+                ]  # Return only the normalized output, not rstd
+
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(normalized_shape),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            # Test without weight (None)
+            ("1d_no_weight", (2, 4, 8), [8]),
+            ("2d_no_weight", (2, 4, 8), [4, 8]),
+            ("4d_no_weight", (2, 3, 4, 8), [8]),
+        ]
+    )
+    def test_rms_norm_without_weight(self, name, input_shape, normalized_shape):
+        """
+        Test RMS norm without weight parameter (weight=None).
+        This ensures the converter handles optional weight correctly.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, None, 1e-5
+                )[0]
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            # Test different epsilon values
+            ("eps_1e5", (2, 4, 8), [8], 1e-5),
+            ("eps_1e6", (2, 4, 8), [8], 1e-6),
+            ("eps_1e4", (2, 4, 8), [8], 1e-4),
+        ]
+    )
+    def test_rms_norm_different_eps(self, name, input_shape, normalized_shape, eps):
+        """
+        Test RMS norm with different epsilon values.
+        Epsilon is critical for numerical stability, especially with small values.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, weight, eps
+                )[0]
+
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(normalized_shape),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_with_dynamic_shape_batch(self):
+        """
+        Test RMS norm with dynamic batch dimension.
+        This is common in inference scenarios where batch size varies.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(x, [128], weight, 1e-6)[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 128),
+                dtype=torch.float32,
+                shape_ranges=[((1, 128), (4, 128), (8, 128))],
+            ),
+            Input(
+                shape=(128,),
+                dtype=torch.float32,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            RMSNorm(),
+            input_specs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_with_dynamic_shape_sequence(self):
+        """
+        Test RMS norm with dynamic sequence length.
+        This is critical for transformer models with variable sequence lengths.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(x, [256], weight, 1e-5)[0]
+
+        input_specs = [
+            Input(
+                shape=(2, -1, 256),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 256), (2, 64, 256), (2, 128, 256))],
+            ),
+            Input(
+                shape=(256,),
+                dtype=torch.float32,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            RMSNorm(),
+            input_specs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_with_dynamic_shape_multi_dim(self):
+        """
+        Test RMS norm with multiple dynamic dimensions.
+        Tests both batch and sequence length being dynamic simultaneously.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(x, [64], weight, 1e-6)[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, -1, 64),
+                dtype=torch.float32,
+                shape_ranges=[((1, 8, 64), (4, 16, 64), (8, 32, 64))],
+            ),
+            Input(
+                shape=(64,),
+                dtype=torch.float32,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            RMSNorm(),
+            input_specs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_2d_input(self):
+        """
+        Test RMS norm with 2D input (batch, features).
+        Common in MLP layers or simple feedforward networks.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(x, [512], weight, 1e-5)[0]
+
+        inputs = [
+            torch.randn(32, 512),
+            torch.randn(512),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_large_hidden_dim(self):
+        """
+        Test RMS norm with larger hidden dimensions typical in modern LLMs.
+        Tests numerical stability and performance with realistic model sizes.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._fused_rms_norm.default(x, [4096], weight, 1e-6)[
+                    0
+                ]
+
+        inputs = [
+            torch.randn(2, 8, 4096),
+            torch.randn(4096),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_rms_norm_flux_pattern(self):
+        """
+        Test RMS norm with pattern similar to FLUX and modern diffusion models.
+        This tests the actual use case that motivated the converter implementation.
+        """
+
+        class RMSNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                # FLUX-style: normalize over last dimension with small epsilon
+                normalized_shape = [x.shape[-1]]
+                return torch.ops.aten._fused_rms_norm.default(
+                    x, normalized_shape, weight, 1e-6
+                )[0]
+
+        inputs = [
+            torch.randn(1, 16, 3072),  # Typical FLUX dimensions
+            torch.randn(3072),
+        ]
+        self.run_test(
+            RMSNorm(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
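Reviewer note (not part of the patch): to sanity-check the arithmetic that process_fused_rms_norm_node lowers into the graph (square, mean over the normalized dimensions, add eps, sqrt, divide, optional weight multiply), the standalone eager-mode sketch below mirrors that sequence and compares it against torch.nn.functional.rms_norm. It is illustrative only: the helper name, shapes, and epsilon are arbitrary, and F.rms_norm is assumed to be available in the installed PyTorch build (added around PyTorch 2.4).

import torch
import torch.nn.functional as F


def decomposed_rms_norm(x, normalized_shape, weight=None, eps=1e-5):
    # Reduce over the last len(normalized_shape) dimensions, mirroring the
    # dims_to_reduce computation in process_fused_rms_norm_node.
    dims_to_reduce = [x.dim() - i - 1 for i in range(len(normalized_shape))]
    mean_of_squares = (x * x).mean(dim=dims_to_reduce, keepdim=True)
    x_normalized = x / torch.sqrt(mean_of_squares + eps)
    if weight is not None:
        x_normalized = x_normalized * weight
    return x_normalized


if __name__ == "__main__":
    x = torch.randn(2, 4, 8)
    weight = torch.randn(8)
    # Assumes torch.nn.functional.rms_norm exists (PyTorch >= 2.4).
    expected = F.rms_norm(x, [8], weight=weight, eps=1e-5)
    actual = decomposed_rms_norm(x, [8], weight=weight, eps=1e-5)
    torch.testing.assert_close(actual, expected)
    print("eager decomposition matches torch.nn.functional.rms_norm")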