From 6d2a84a9fe3fa2b851d79964c74e79cf7eb080ef Mon Sep 17 00:00:00 2001 From: Igor Fedorov Date: Thu, 23 Apr 2026 16:25:51 -0700 Subject: [PATCH] Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The 730M dense model checkpoint uses several rlformers features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output. This diff adds support for 8 missing features: 1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup 2. `qk_norm_before_rope` — conversion from GenAI args (attention code already supported it) 3. `scale_query_by` — custom scalar multiplier on Q after QK norm 4. `use_attn_o_gate` — sigmoid gate on attention output using a learned linear projection of the layer input 5. `use_attn_o_norm` — scaleless per-head RMSNorm on attention output (applied before o_gate) 6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both attention and FFN residual connections 7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)` 8. `output_soft_cap_temp` — `tanh(logits/temp) * temp` soft capping on output logits All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion. Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`. Differential Revision: D102030169 --- examples/models/llama/attention.py | 57 +++++++++++++------ examples/models/llama/llama_transformer.py | 66 +++++++++++++++++++--- examples/models/llama/model_args.py | 10 ++++ examples/models/llama/norm.py | 40 ++++++------- 4 files changed, 129 insertions(+), 44 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index d6dff173072..56891d48d3c 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated +from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm from executorch.examples.models.llama.rope import Rope @@ -375,6 +375,9 @@ def __init__( self.qk_norm_before_rope = args.qk_norm_before_rope self.use_q_gate = args.use_q_gate self.enable_dynamic_shape = args.enable_dynamic_shape + self.scale_query_by = args.scale_query_by + self.use_attn_o_gate = args.use_attn_o_gate + self.use_attn_o_norm = args.use_attn_o_norm q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1) # YOCO: Determine if this is a KV shared layer (receives shared KV from donor). @@ -417,17 +420,26 @@ def __init__( def _init_norms(self, args: ModelArgs) -> None: """Initialize QK normalization layers.""" if self.use_qk_norm: - self.q_norm_fn = RMSNorm( - self.head_dim, - eps=args.norm_eps, - add_unit_offset=args.rms_norm_add_unit_offset, - ) - if self.has_kv_weights: - self.k_norm_fn = RMSNorm( + if args.qk_norm_affine: + self.q_norm_fn = RMSNorm( self.head_dim, eps=args.norm_eps, add_unit_offset=args.rms_norm_add_unit_offset, ) + if self.has_kv_weights: + self.k_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + else: + self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps) + if self.has_kv_weights: + self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps) + if self.use_attn_o_norm: + self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps) + if self.use_attn_o_gate: + self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False) def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None: """Initialize Q/K/V/O projection layers.""" @@ -477,14 +489,14 @@ def _prepare_qkv_shared( k, v = shared_kv if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm_fn(q) + q = self.q_norm_fn(q) * self.scale_query_by # Apply RoPE to Q only (K already has RoPE from donor layer) q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin) q = q.transpose(1, 2) if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm_fn(q) + q = self.q_norm_fn(q) * self.scale_query_by return q, k, v @@ -507,7 +519,7 @@ def _prepare_qkv( v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm_fn(q) + q = self.q_norm_fn(q) * self.scale_query_by k = self.k_norm_fn(k) q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) @@ -517,7 +529,7 @@ def _prepare_qkv( v = v.transpose(1, 2) if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm_fn(q) + q = self.q_norm_fn(q) * self.scale_query_by k = self.k_norm_fn(k) return q, k, v @@ -582,8 +594,7 @@ def forward( ) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) - if gate is not None: - output = output * torch.sigmoid(gate) + output = self._apply_output_transforms(output, x, gate, bsz, seqlen) if shared_kv is None and self.num_kv_shared_layers > 0: update = {"kv_to_share": (k, v)} @@ -602,13 +613,27 @@ def forward( output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) output = output.transpose(1, 2).reshape(bsz, seqlen, -1) - if gate is not None: - output = output * torch.sigmoid(gate) + output = self._apply_output_transforms(output, x, gate, bsz, seqlen) output = self.wo(output) return output, None + def _apply_output_transforms( + self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int + ) -> torch.Tensor: + if self.use_attn_o_norm or self.use_attn_o_gate: + output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim) + if self.use_attn_o_norm: + output_4d = self.o_norm(output_4d) + if self.use_attn_o_gate: + og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + output_4d = torch.sigmoid(og) * output_4d + output = output_4d.reshape(bsz, seqlen, -1) + if gate is not None: + output = output * torch.sigmoid(gate) + return output + def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index e74ae810a02..921c6986af1 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -7,6 +7,7 @@ # Please refer to README.md in the same folder for more information. +import math from typing import Any, Dict, Optional, Tuple, Union import torch @@ -20,7 +21,7 @@ ) from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale, ScalelessRMSNorm from executorch.examples.models.llama.rope import Rope from torch import nn @@ -51,6 +52,23 @@ def _is_kv_shared_layer( return layer_idx >= first_shared and first_shared > 0 +class NormPreservingResidualConnection(nn.Module): + def __init__(self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3): + super().__init__() + self.eps = eps + self.temperature = temperature + p = max(0.0 + eps, min(1.0 - eps, init_scale)) + init_param = math.log(p / (1.0 - p)) * temperature + self.gate = nn.Parameter(torch.full((dim,), init_param)) + + def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor: + w = self.gate.view(*([1] * (stream.ndim - 1)), -1) + beta = torch.sigmoid(w / self.temperature) + alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta) + alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps)) + return alpha * stream + beta * branch + + class ConditionalFeedForward(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -99,7 +117,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock(nn.Module): def __init__( - self, args: ModelArgs, attention: Attention, mlp_type: str = "default" + self, args: ModelArgs, attention: Attention, mlp_type: str = "default", + layer_id: int = 0, ): """ Transformer block with support for pre-norm and post-norm. @@ -110,6 +129,7 @@ def __init__( the attention type is registered in the ATTENTION_REGISTRY. mlp_type (str): MLP type for this layer. "default" for standard FFN, "skip" for no FFN block. + layer_id (int): layer index, used for residual gate initialization. """ super().__init__() self.use_kv_cache = args.use_kv_cache @@ -118,6 +138,7 @@ def __init__( self.head_dim = args.head_dim self.attention = attention self.mlp_type = mlp_type.lower() + self.use_residual_gate = args.use_residual_gate assert ( args.hidden_dim is not None @@ -150,6 +171,16 @@ def __init__( add_unit_offset=args.rms_norm_add_unit_offset, ) + if args.use_residual_gate: + attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5 + ffn_init = 1.0 / (2 * layer_id + 2) + self.add_attn = NormPreservingResidualConnection(dim=args.dim, init_scale=attn_init) + self.add_ffn = NormPreservingResidualConnection(dim=args.dim, init_scale=ffn_init) + self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps) + + if args.use_ffn_learnable_scales and self.mlp_type != "skip": + self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps) + @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": """ @@ -169,21 +200,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock": mlp_type = args.mlp_type[layer_id] cls = ATTENTION_REGISTRY[args.attention_type] attention = cls(args, layer_id, rope, **args.attention_kwargs) - return TransformerBlock(args, attention, mlp_type=mlp_type) + return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id) def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN h, attn_options_update = self.attention( self.attention_norm(x), freqs_cos, freqs_sin, **attn_options ) if not isinstance(self.attention, AttentionSkip): - h = x + h + if self.use_residual_gate: + if hasattr(self, "post_attn_norm"): + h = self.post_attn_norm(h) + h = self.add_attn(stream=x, branch=h) + else: + h = x + h if self.mlp_type == "skip": out = h elif hasattr(self, "block_sparse_moe"): - out = h + self.block_sparse_moe(self.ffn_norm(h)) + ffn_out = self.block_sparse_moe(self.ffn_norm(h)) + if hasattr(self, "post_ffn_norm"): + ffn_out = self.post_ffn_norm(ffn_out) + if self.use_residual_gate: + out = self.add_ffn(stream=h, branch=ffn_out) + else: + out = h + ffn_out else: - out = h + self.feed_forward(self.ffn_norm(h)) + ffn_out = self.feed_forward(self.ffn_norm(h)) + if hasattr(self, "post_ffn_norm"): + ffn_out = self.post_ffn_norm(ffn_out) + if self.use_residual_gate: + out = self.add_ffn(stream=h, branch=ffn_out) + else: + out = h + ffn_out return out, attn_options_update @@ -371,7 +419,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: and model_args.layer_types[layer_id] == "skip_attention" ): attention = AttentionSkip() - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id) layers.append(transformer_block) elif ( model_args.layer_types @@ -386,13 +434,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: attention = linear_cls( model_args, layer_id, rope, **model_args.attention_kwargs ) - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id) layers.append(transformer_block) else: attention = cls( model_args, layer_id, rope, **model_args.attention_kwargs ) # pyre-ignore[45] - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id) layers.append(transformer_block) return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 104e9fe2ddd..a132d81f408 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -102,6 +102,7 @@ class ModelArgs: apply_output: bool = True # Use output layer (unembedding) inside the transformer use_qk_norm: bool = False # apply normalization to q and k in the attention qk_norm_before_rope: bool = False # when to apply qk norm + qk_norm_affine: bool = True # whether QK norm has learnable weight (False = scaleless) residual_multiplier: Optional[float] = ( None # Scaling factor applied to the residual hidden states ) @@ -162,6 +163,15 @@ class ModelArgs: final_logit_softcapping: Optional[float] = None attn_logit_softcapping: Optional[float] = None + # rlformers forward-pass features for on-device model parity + normalize_tok_embeddings: bool = False + scale_query_by: float = 1.0 + use_attn_o_gate: bool = False + use_attn_o_norm: bool = False + use_residual_gate: bool = False + use_ffn_learnable_scales: bool = False + output_soft_cap_temp: Optional[float] = None + def __post_init__(self): # noqa: C901 if self.n_kv_heads is None: self.n_kv_heads = self.n_heads diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 0189c88b13b..03efac81671 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -32,33 +32,35 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. + return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) - Args: - x (torch.Tensor): The input tensor. + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.add_unit_offset: + return output * (1.0 + self.weight.float()).type_as(x) + return output * self.weight.type_as(x) - Returns: - torch.Tensor: The normalized tensor. - """ - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) +class ScalelessRMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps def forward(self, x): - """ - Forward pass through the RMSNorm layer. + return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The output tensor after applying RMSNorm. +class RMSNormWithInputScale(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.dim = dim + self.weight = torch.nn.Parameter(torch.ones(dim)) - """ - output = self._norm(x.float()).type_as(x) - if self.add_unit_offset: - return output * (1.0 + self.weight.float()).type_as(x) - return output * self.weight.type_as(x) + def forward(self, x): + scaled = self.weight * x + return scaled * torch.rsqrt((scaled * scaled).mean(-1, keepdim=True) + self.eps) class RMSNormGated(nn.Module):