28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -219,6 +219,34 @@ def aten_ops_native_group_norm(
)


@dynamo_tensorrt_converter(
torch.ops.aten._fused_rms_norm.default,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_fused_rms_norm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.fused_rms_norm(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
normalized_shape=args[1],
weight=args_bounds_check(args, 2),
eps=args_bounds_check(args, 3),
)


def parse_cat_args(
args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Tuple[List[Any], int]:
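For context, a minimal sketch of how this converter path might be exercised end to end. The module name, shapes, and eps value are illustrative assumptions (not from the PR), and the direct call to torch.ops.aten._fused_rms_norm simply mirrors the argument order handled above, assuming the op reaches the TensorRT partition without being decomposed:

import torch
import torch_tensorrt

class RMSNormModule(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The op returns (output, rstd); only the normalized output is used here.
        return torch.ops.aten._fused_rms_norm(x, [x.shape[-1]], self.weight, 1e-6)[0]

model = RMSNormModule(64).cuda().eval()
inputs = [torch.randn(2, 8, 64, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
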
86 changes: 86 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -854,3 +854,89 @@ def cdist_forward(
return_indices=False,
)
return dist


def fused_rms_norm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: trt.ITensor,
normalized_shape: List[int],
weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
eps: Optional[float],
) -> Tuple[trt.ITensor, Optional[torch.Tensor]]:
"""
RMS Normalization: output = input / sqrt(mean(input^2) + eps) * weight

Args:
ctx: ConversionContext containing the TensorRT network
target: Target of calling node
source_ir: SourceIR of calling converter
name: Name of the calling layer
input: Input tensor to normalize
normalized_shape: Shape over which to normalize (list of ints)
weight: Optional weight/scale parameter
eps: Epsilon for numerical stability (default: 1e-5)

Returns:
Tuple of (normalized_output, rstd_placeholder)
Note: rstd (the reciprocal standard deviation) is returned as a None placeholder
"""
if eps is None:
eps = 1e-5

    # Calculate dimensions to normalize over (similar to layer_norm);
    # normalized_shape specifies the last N dimensions of the input.
    dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
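    # Example (hypothetical shapes): input.shape = (2, 8, 64) with
    # normalized_shape = [64] reduces over dims = [2], the last axis.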

# Square the input
input_squared = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_input_squared", input, input
)

# Compute mean of squared values
mean_squared = impl.reduce.mean(
ctx, target, source_ir, f"{name}_mean_squared", input_squared, dim=dims, keepdim=True
)

# Add epsilon for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", input.dtype)
mean_squared_eps = impl.elementwise.add(
ctx, target, source_ir, f"{name}_mean_squared_eps", mean_squared, eps_tensor
)

# Compute RMS = sqrt(mean(input^2) + eps)
rms = impl.unary.sqrt(ctx, target, source_ir, f"{name}_rms", mean_squared_eps)

# Normalize: input / rms
normalized = impl.elementwise.div(
ctx, target, source_ir, f"{name}_normalized", input, rms
)

# Apply weight (scale) if provided
if weight is not None:
weight_trt = get_trt_tensor(ctx, weight, f"{name}_weight")

# Cast weight to match input dtype
weight_trt = cast_trt_tensor(
ctx, weight_trt, input.dtype, f"{name}_weight_cast", target, source_ir
)

# Expand weight to match input shape if needed
if tuple(input.shape) != tuple(weight_trt.shape):
weight_trt = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight", weight_trt, input.shape
)

# Multiply normalized output by weight
output = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_output", normalized, weight_trt
)
else:
output = normalized

# Return (output, rstd_placeholder)
# PyTorch returns (output, rstd) but we return None for rstd as it's typically not used
return output, None
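
For sanity-checking the lowering, a pure-PyTorch reference of the same computation as the docstring formula, output = input / sqrt(mean(input^2) + eps) * weight. The function name is a placeholder of mine; its result should correspond to the first element of the pair returned above:

import torch
from typing import List, Optional

def rms_norm_reference(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: Optional[float] = None,
) -> torch.Tensor:
    # Match the converter's default epsilon.
    if eps is None:
        eps = 1e-5
    # Reduce over the trailing len(normalized_shape) dimensions, as above.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    rms = torch.sqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    out = x / rms
    return out if weight is None else out * weight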