@@ -18,6 +18,7 @@
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
from .repair_input_as_output import repair_input_as_output
from .replace_fused_rms_norm import replace_fused_rms_norm
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast

@@ -28,6 +29,7 @@
]

post_lowering_pass_list = [
replace_fused_rms_norm,
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
85 changes: 85 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
@@ -0,0 +1,85 @@
import copy
import logging
import operator

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def replace_fused_rms_norm(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace aten._fused_rms_norm nodes in the graph with an equivalent decomposition into elementary ops."""
count = 0
for node in gm.graph.nodes:
if node.target == torch.ops.aten._fused_rms_norm.default:
            process_fused_rms_norm_node(node, gm)
count += 1

logger.debug(f"Replaced {count} fused rms norm nodes:\n{gm.graph}")

gm = clean_up_graph_after_modifications(gm)

return gm


def process_fused_rms_norm_node(
    node: torch.fx.Node, gm: torch.fx.GraphModule
) -> torch.fx.Node:
    """Decompose a single aten._fused_rms_norm node into mul/mean/add/sqrt/div ops."""
x, shape, weight, eps = node.args[0], node.args[1], node.args[2], node.args[3]
if eps is None:
eps = 1e-5
# Calculate dimensions to normalize over (similar to layer_norm)
# normalized_shape specifies the last N dimensions
x_dim = len(node.meta["val"][0].shape)
dims_to_reduce = []
for i in range(len(shape)):
dims_to_reduce.append(x_dim - i - 1)

with gm.graph.inserting_before(node):
        # Decompose fused RMS norm: x / sqrt(mean(x^2) + eps), then scale by weight
x_squared = gm.graph.call_function(
torch.ops.aten.mul.Tensor,
args=(x, x),
)
x_squared_sum = gm.graph.call_function(
torch.ops.aten.mean.dim,
args=(x_squared, dims_to_reduce, True),
)
x_squared_sum_eps = gm.graph.call_function(
torch.ops.aten.add.Tensor,
args=(x_squared_sum, eps),
)
x_squared_sum_eps_sqrt = gm.graph.call_function(
torch.ops.aten.sqrt.default,
args=(x_squared_sum_eps,),
)
x_normalized = gm.graph.call_function(
torch.ops.aten.div.Tensor,
args=(x, x_squared_sum_eps_sqrt),
)
if weight is not None:
x_normalized = gm.graph.call_function(
torch.ops.aten.mul.Tensor,
args=(x_normalized, weight),
)

x_normalized.meta = {}

for user in list(node.users):
if user.op == "call_function" and user.target == operator.getitem:
            # getitem users extract the output tensor; copy the fused node's metadata onto the replacement once
if not x_normalized.meta:
x_normalized.meta = copy.copy(node.meta)
user.replace_all_uses_with(x_normalized)
gm.graph.erase_node(user)

gm.graph.erase_node(node)

return x_normalized
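As an aside (not part of this diff), the decomposition emitted by the pass can be sanity-checked in eager mode. The snippet below is a minimal sketch that assumes torch.nn.functional.rms_norm is available (PyTorch >= 2.4) and uses an illustrative shape:

import torch
import torch.nn.functional as F

# Same decomposition as the pass: x / sqrt(mean(x^2) + eps), scaled by weight
x = torch.randn(2, 4, 8)
weight = torch.randn(8)
eps = 1e-5

rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
decomposed = x / rms * weight

# Eager reference (assumed available in PyTorch >= 2.4)
reference = F.rms_norm(x, [8], weight=weight, eps=eps)
assert torch.allclose(decomposed, reference, rtol=1e-4, atol=1e-5)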
278 changes: 278 additions & 0 deletions tests/py/dynamo/lowering/test_fused_rms_norm_aten.py
@@ -0,0 +1,278 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from ..conversion.harness import DispatchTestCase


class TestFusedRMSNormConverter(DispatchTestCase):
"""
Tests for the aten._fused_rms_norm.default converter.
RMS Normalization formula: output = input / sqrt(mean(input^2) + eps) * weight
The operation signature is: _fused_rms_norm(input, normalized_shape, weight, eps)
Returns: (output, rstd) - where rstd is the reciprocal standard deviation
"""

@parameterized.expand(
[
# Test normalizing over last dimension
("1d_last_dim", (2, 4, 8), [8]),
# Test normalizing over last 2 dimensions
("2d_last_two_dims", (2, 4, 8), [4, 8]),
# Test normalizing over all dimensions
("3d_all_dims", (2, 4, 8), [2, 4, 8]),
# Test with 4D tensor, last dimension
("4d_last_dim", (2, 3, 4, 8), [8]),
# Test with 4D tensor, last 2 dimensions
("4d_last_two_dims", (2, 3, 4, 8), [4, 8]),
# Test with 4D tensor, last 3 dimensions
("4d_last_three_dims", (2, 3, 4, 8), [3, 4, 8]),
]
)
def test_rms_norm_with_weight(self, name, input_shape, normalized_shape):
"""
Test RMS norm with weight parameter across various tensor shapes.
This tests:
- Correct dimension calculation for normalization
- Weight broadcasting/expansion to match input shape
- Output correctness vs PyTorch reference
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
                # Return only the normalized output, not rstd
                return torch.ops.aten._fused_rms_norm.default(
                    x, normalized_shape, weight, 1e-5
                )[0]

inputs = [
torch.randn(input_shape),
torch.randn(normalized_shape),
]
self.run_test(
RMSNorm(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

@parameterized.expand(
[
# Test without weight (None)
("1d_no_weight", (2, 4, 8), [8]),
("2d_no_weight", (2, 4, 8), [4, 8]),
("4d_no_weight", (2, 3, 4, 8), [8]),
]
)
def test_rms_norm_without_weight(self, name, input_shape, normalized_shape):
"""
Test RMS norm without weight parameter (weight=None).
This ensures the converter handles optional weight correctly.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._fused_rms_norm.default(
x, normalized_shape, None, 1e-5
)[0]

inputs = [torch.randn(input_shape)]
self.run_test(
RMSNorm(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

@parameterized.expand(
[
# Test different epsilon values
("eps_1e5", (2, 4, 8), [8], 1e-5),
("eps_1e6", (2, 4, 8), [8], 1e-6),
("eps_1e4", (2, 4, 8), [8], 1e-4),
]
)
def test_rms_norm_different_eps(self, name, input_shape, normalized_shape, eps):
"""
Test RMS norm with different epsilon values.
Epsilon is critical for numerical stability, especially with small values.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten._fused_rms_norm.default(
x, normalized_shape, weight, eps
)[0]

inputs = [
torch.randn(input_shape),
torch.randn(normalized_shape),
]
self.run_test(
RMSNorm(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

def test_rms_norm_with_dynamic_shape_batch(self):
"""
Test RMS norm with dynamic batch dimension.
This is common in inference scenarios where batch size varies.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten._fused_rms_norm.default(x, [128], weight, 1e-6)[0]

input_specs = [
Input(
shape=(-1, 128),
dtype=torch.float32,
shape_ranges=[((1, 128), (4, 128), (8, 128))],
),
Input(
shape=(128,),
dtype=torch.float32,
),
]

self.run_test_with_dynamic_shape(
RMSNorm(),
input_specs,
use_dynamo_tracer=True,
enable_passes=True,
)

def test_rms_norm_with_dynamic_shape_sequence(self):
"""
Test RMS norm with dynamic sequence length.
This is critical for transformer models with variable sequence lengths.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten._fused_rms_norm.default(x, [256], weight, 1e-5)[0]

input_specs = [
Input(
shape=(2, -1, 256),
dtype=torch.float32,
shape_ranges=[((2, 16, 256), (2, 64, 256), (2, 128, 256))],
),
Input(
shape=(256,),
dtype=torch.float32,
),
]

self.run_test_with_dynamic_shape(
RMSNorm(),
input_specs,
use_dynamo_tracer=True,
enable_passes=True,
)

def test_rms_norm_with_dynamic_shape_multi_dim(self):
"""
Test RMS norm with multiple dynamic dimensions.
Tests both batch and sequence length being dynamic simultaneously.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten._fused_rms_norm.default(x, [64], weight, 1e-6)[0]

input_specs = [
Input(
shape=(-1, -1, 64),
dtype=torch.float32,
shape_ranges=[((1, 8, 64), (4, 16, 64), (8, 32, 64))],
),
Input(
shape=(64,),
dtype=torch.float32,
),
]

self.run_test_with_dynamic_shape(
RMSNorm(),
input_specs,
use_dynamo_tracer=True,
enable_passes=True,
)

def test_rms_norm_2d_input(self):
"""
Test RMS norm with 2D input (batch, features).
Common in MLP layers or simple feedforward networks.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten._fused_rms_norm.default(x, [512], weight, 1e-5)[0]

inputs = [
torch.randn(32, 512),
torch.randn(512),
]
self.run_test(
RMSNorm(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

def test_rms_norm_large_hidden_dim(self):
"""
Test RMS norm with larger hidden dimensions typical in modern LLMs.
Tests numerical stability and performance with realistic model sizes.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten._fused_rms_norm.default(x, [4096], weight, 1e-6)[
0
]

inputs = [
torch.randn(2, 8, 4096),
torch.randn(4096),
]
self.run_test(
RMSNorm(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

def test_rms_norm_flux_pattern(self):
"""
Test RMS norm with pattern similar to FLUX and modern diffusion models.
This tests the actual use case that motivated the converter implementation.
"""

class RMSNorm(torch.nn.Module):
def forward(self, x, weight):
# FLUX-style: normalize over last dimension with small epsilon
normalized_shape = [x.shape[-1]]
return torch.ops.aten._fused_rms_norm.default(
x, normalized_shape, weight, 1e-6
)[0]

inputs = [
torch.randn(1, 16, 3072), # Typical FLUX dimensions
torch.randn(3072),
]
self.run_test(
RMSNorm(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)


if __name__ == "__main__":
run_tests()