diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index d6dff173072..56891d48d3c 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated +from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm from executorch.examples.models.llama.rope import Rope @@ -375,6 +375,9 @@ def __init__( self.qk_norm_before_rope = args.qk_norm_before_rope self.use_q_gate = args.use_q_gate self.enable_dynamic_shape = args.enable_dynamic_shape + self.scale_query_by = args.scale_query_by + self.use_attn_o_gate = args.use_attn_o_gate + self.use_attn_o_norm = args.use_attn_o_norm q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1) # YOCO: Determine if this is a KV shared layer (receives shared KV from donor). 
def _init_norms(self, args: ModelArgs) -> None:
    """Create the optional normalization / gating sub-modules.

    Depending on the flags already stored on ``self``:

    * ``use_qk_norm``: builds ``q_norm_fn`` (and ``k_norm_fn`` when
      ``has_kv_weights`` is set). ``args.qk_norm_affine`` selects between a
      weighted ``RMSNorm`` and a weight-free ``ScalelessRMSNorm``.
    * ``use_attn_o_norm``: builds ``o_norm``, a ``ScalelessRMSNorm`` over
      ``head_dim``.
    * ``use_attn_o_gate``: builds ``og``, a bias-free linear projection from
      ``args.dim`` to ``n_heads * head_dim``.

    Args:
        args: Model configuration supplying ``norm_eps``,
            ``rms_norm_add_unit_offset``, ``qk_norm_affine`` and ``dim``.
    """

    def build_qk_norm():
        # One factory so the Q and K norms are always configured the same way.
        if args.qk_norm_affine:
            return RMSNorm(
                self.head_dim,
                eps=args.norm_eps,
                add_unit_offset=args.rms_norm_add_unit_offset,
            )
        return ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)

    if self.use_qk_norm:
        self.q_norm_fn = build_qk_norm()
        if self.has_kv_weights:
            self.k_norm_fn = build_qk_norm()

    if self.use_attn_o_norm:
        self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)

    if self.use_attn_o_gate:
        gate_width = self.n_heads * self.head_dim
        self.og = nn.Linear(args.dim, gate_width, bias=False)
def _apply_output_transforms(
    self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
) -> torch.Tensor:
    """Post-process the attention output before the output projection.

    Steps, each applied only when enabled:

    1. ``use_attn_o_norm``: apply ``self.o_norm`` to the output viewed as
       ``(bsz, seqlen, n_local_heads, head_dim)``.
    2. ``use_attn_o_gate``: multiply that per-head view by
       ``sigmoid(self.og(x))``, reshaped to the same 4-D layout.
    3. ``gate is not None``: multiply the (re-flattened) output by
       ``sigmoid(gate)``.

    Args:
        output: Attention result; it is viewed as
            ``(bsz, seqlen, n_local_heads, head_dim)`` when step 1 or 2 runs.
        x: Tensor fed to ``self.og`` to produce the per-head output gate.
        gate: Optional gate tensor, or ``None`` to skip step 3.
        bsz: Batch size used for the reshapes.
        seqlen: Sequence length used for the reshapes.

    Returns:
        The transformed output; ``(bsz, seqlen, -1)`` when step 1 or 2 ran,
        otherwise the shape of ``output``.
    """
    needs_per_head_view = self.use_attn_o_norm or self.use_attn_o_gate
    if needs_per_head_view:
        head_shape = (bsz, seqlen, self.n_local_heads, self.head_dim)
        per_head = output.view(*head_shape)
        if self.use_attn_o_norm:
            per_head = self.o_norm(per_head)
        if self.use_attn_o_gate:
            head_gate = self.og(x).view(*head_shape)
            per_head = torch.sigmoid(head_gate) * per_head
        output = per_head.reshape(bsz, seqlen, -1)

    if gate is None:
        return output
    return output * torch.sigmoid(gate)
class NormPreservingResidualConnection(nn.Module):
    """Learned per-channel residual mix ``alpha * stream + beta * branch``.

    With ``z = gate / temperature``:

    * ``beta = sigmoid(z)``
    * ``alpha = sqrt(clamp(sigmoid(-z) * (1 + beta), min=eps))``

    Because ``sigmoid(-z) == 1 - sigmoid(z)``, the unclamped value is
    ``1 - beta**2``, so ``alpha**2 + beta**2 == 1`` whenever the clamp is
    inactive.
    """

    def __init__(
        self,
        dim: int,
        init_scale: float,
        temperature: float = 0.3,
        eps: float = 1e-3,
    ):
        """
        Args:
            dim: Number of channels; ``gate`` holds one value per channel.
            init_scale: Desired initial ``beta``. It is clamped to
                ``[eps, 1 - eps]`` so the logit below stays finite.
            temperature: Positive divisor applied to the gate before the
                sigmoid.
            eps: Margin for the ``init_scale`` clamp and lower bound for
                ``alpha**2`` in ``forward``. Must lie in ``(0, 0.5)``.

        Raises:
            ValueError: If ``temperature`` is not positive or ``eps`` is
                outside ``(0, 0.5)``.
        """
        super().__init__()
        # temperature == 0 would turn every gate into 0/0 = NaN in forward,
        # and a negative value silently swaps the roles of alpha and beta.
        if not temperature > 0.0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        # eps >= 0.5 makes the [eps, 1 - eps] clamp range empty/inverted.
        if not 0.0 < eps < 0.5:
            raise ValueError(f"eps must be in (0, 0.5), got {eps}")
        self.eps = eps
        self.temperature = temperature
        p = max(eps, min(1.0 - eps, init_scale))
        # Inverse of beta = sigmoid(gate / temperature) evaluated at beta = p.
        init_param = math.log(p / (1.0 - p)) * temperature
        self.gate = nn.Parameter(torch.full((dim,), init_param))

    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
        """Blend the residual ``stream`` with the new ``branch`` output.

        Args:
            stream: Residual input whose last dimension has size ``dim``.
            branch: Sub-layer output, broadcast-compatible with ``stream``.

        Returns:
            ``alpha * stream + beta * branch``.
        """
        # The 1-D gate broadcasts over all leading dimensions of ``stream``.
        z = self.gate / self.temperature
        beta = torch.sigmoid(z)
        # Written as sigmoid(-z) * (1 + beta) rather than 1 - beta**2 to avoid
        # cancellation when beta is close to 1.
        alpha_sq = torch.sigmoid(-z) * (1.0 + beta)
        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
        return alpha * stream + beta * branch
def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
    """Run the attention sub-layer followed by the optional FFN sub-layer.

    Attention: ``self.attention`` is called on ``self.attention_norm(x)``.
    Unless the attention module is an ``AttentionSkip``, its output is merged
    with ``x`` either through ``self.add_attn`` (when ``use_residual_gate``
    is set, after ``post_attn_norm`` if that module exists) or by plain
    addition.

    FFN: skipped when ``mlp_type == "skip"``. Otherwise ``block_sparse_moe``
    is used if the block has one, else ``feed_forward``; the result passes
    through ``post_ffn_norm`` when present and is merged with the attention
    result through ``self.add_ffn`` or plain addition, as above.

    Returns:
        ``(out, attn_options_update)`` where ``attn_options_update`` is the
        second value returned by the attention module.
    """
    attn_out, attn_options_update = self.attention(
        self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
    )

    if isinstance(self.attention, AttentionSkip):
        # Skipped attention: its output is forwarded without a residual merge.
        h = attn_out
    elif self.use_residual_gate:
        if hasattr(self, "post_attn_norm"):
            attn_out = self.post_attn_norm(attn_out)
        h = self.add_attn(stream=x, branch=attn_out)
    else:
        h = x + attn_out

    if self.mlp_type == "skip":
        return h, attn_options_update

    # The MoE and dense paths differ only in which module produces ffn_out.
    if hasattr(self, "block_sparse_moe"):
        ffn_module = self.block_sparse_moe
    else:
        ffn_module = self.feed_forward
    ffn_out = ffn_module(self.ffn_norm(h))

    if hasattr(self, "post_ffn_norm"):
        ffn_out = self.post_ffn_norm(ffn_out)

    if self.use_residual_gate:
        out = self.add_ffn(stream=h, branch=ffn_out)
    else:
        out = h + ffn_out
    return out, attn_options_update
class ScalelessRMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension with no learnable weight.

    Computes ``x * rsqrt(mean(x * x, dim=-1) + eps)``. The arithmetic runs in
    at least float32 and the result is cast back to the input dtype, so
    half-precision inputs do not overflow when squared.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: Size of the normalized dimension. Stored on the module;
                ``forward`` does not read it.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` divided by its root-mean-square over the last axis."""
        # In float16, x * x overflows to inf once |x| exceeds ~255; rsqrt(inf)
        # is 0, which would silently zero the whole row. Upcast first, but keep
        # float64 inputs in float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        wide = x.to(compute_dtype)
        inv_rms = torch.rsqrt((wide * wide).mean(-1, keepdim=True) + self.eps)
        return (wide * inv_rms).type_as(x)
class RMSNormWithInputScale(torch.nn.Module):
    """RMS normalization applied to a per-channel scaled input.

    The input is first multiplied by the learnable ``weight`` and the scaled
    tensor is then divided by its own root-mean-square over the last
    dimension: ``s = weight * x``; ``out = s * rsqrt(mean(s * s, -1) + eps)``.

    The arithmetic runs in at least float32 and the result is cast back to
    the dtype of ``x``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Args:
            dim: Size of the last dimension; ``weight`` has shape ``(dim,)``
                and is initialized to ones.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by ``weight`` and RMS-normalize the result."""
        # Squaring in float16 overflows for |weight * x| > ~255 and would zero
        # the output, so widen before the reduction. float64 operands are kept
        # in float64.
        compute_dtype = torch.promote_types(
            torch.promote_types(x.dtype, self.weight.dtype), torch.float32
        )
        scaled = self.weight.to(compute_dtype) * x.to(compute_dtype)
        inv_rms = torch.rsqrt((scaled * scaled).mean(-1, keepdim=True) + self.eps)
        # Cast back so a float32 weight does not silently promote
        # half-precision activations.
        return (scaled * inv_rms).type_as(x)