Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions examples/models/llama/attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
Expand All @@ -7,7 +7,7 @@
import torch.nn.functional as F
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm
from executorch.examples.models.llama.rope import Rope


Expand Down Expand Up @@ -375,6 +375,9 @@
self.qk_norm_before_rope = args.qk_norm_before_rope
self.use_q_gate = args.use_q_gate
self.enable_dynamic_shape = args.enable_dynamic_shape
self.scale_query_by = args.scale_query_by
self.use_attn_o_gate = args.use_attn_o_gate
self.use_attn_o_norm = args.use_attn_o_norm
q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

# YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
Expand Down Expand Up @@ -417,17 +420,26 @@
def _init_norms(self, args: ModelArgs) -> None:
"""Initialize QK normalization layers."""
if self.use_qk_norm:
self.q_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
if args.qk_norm_affine:
self.q_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
else:
self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.has_kv_weights:
self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.use_attn_o_norm:
self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.use_attn_o_gate:
self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)

def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
"""Initialize Q/K/V/O projection layers."""
Expand Down Expand Up @@ -477,14 +489,14 @@
k, v = shared_kv

if self.use_qk_norm and self.qk_norm_before_rope:
q = self.q_norm_fn(q)
q = self.q_norm_fn(q) * self.scale_query_by

# Apply RoPE to Q only (K already has RoPE from donor layer)
q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
q = q.transpose(1, 2)

if self.use_qk_norm and not self.qk_norm_before_rope:
q = self.q_norm_fn(q)
q = self.q_norm_fn(q) * self.scale_query_by

return q, k, v

Expand All @@ -507,7 +519,7 @@
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

if self.use_qk_norm and self.qk_norm_before_rope:
q = self.q_norm_fn(q)
q = self.q_norm_fn(q) * self.scale_query_by
k = self.k_norm_fn(k)

q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
Expand All @@ -517,7 +529,7 @@
v = v.transpose(1, 2)

if self.use_qk_norm and not self.qk_norm_before_rope:
q = self.q_norm_fn(q)
q = self.q_norm_fn(q) * self.scale_query_by
k = self.k_norm_fn(k)

return q, k, v
Expand Down Expand Up @@ -582,8 +594,7 @@
)

output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
if gate is not None:
output = output * torch.sigmoid(gate)
output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

if shared_kv is None and self.num_kv_shared_layers > 0:
update = {"kv_to_share": (k, v)}
Expand All @@ -602,13 +613,27 @@
output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
if gate is not None:
output = output * torch.sigmoid(gate)
output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

output = self.wo(output)

return output, None

def _apply_output_transforms(
    self, output: torch.Tensor, x: torch.Tensor, gate: Optional[torch.Tensor], bsz: int, seqlen: int
) -> torch.Tensor:
    """Apply optional per-head output norm / output gate, then the legacy scalar gate.

    Args:
        output: attention output, shape (bsz, seqlen, n_local_heads * head_dim).
        x: the block input (pre-attention hidden states), used to compute the
            output gate projection when ``use_attn_o_gate`` is set.
        gate: optional pre-computed gate tensor; when present it is passed
            through sigmoid and multiplied into the (flattened) output.
        bsz: batch size.
        seqlen: sequence length.

    Returns:
        Transformed output, same shape as ``output``.
    """
    if self.use_attn_o_norm or self.use_attn_o_gate:
        # Both transforms operate per head, so temporarily restore the
        # (bsz, seqlen, heads, head_dim) layout.
        output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        if self.use_attn_o_norm:
            output_4d = self.o_norm(output_4d)
        if self.use_attn_o_gate:
            # Gate is projected from the block input x, not from the attention output.
            og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
            output_4d = torch.sigmoid(og) * output_4d
        output = output_4d.reshape(bsz, seqlen, -1)
    if gate is not None:
        # Legacy whole-output gate path, applied after the per-head transforms.
        output = output * torch.sigmoid(gate)
    return output


def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
Expand Down
66 changes: 57 additions & 9 deletions examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
Expand All @@ -7,6 +7,7 @@

# Please refer to README.md in the same folder for more information.

import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
Expand All @@ -20,7 +21,7 @@
)
from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale, ScalelessRMSNorm
from executorch.examples.models.llama.rope import Rope
from torch import nn

Expand Down Expand Up @@ -51,6 +52,23 @@
return layer_idx >= first_shared and first_shared > 0


class NormPreservingResidualConnection(nn.Module):
    """Learnable gated residual merge: ``alpha * stream + beta * branch``.

    ``beta`` is a per-channel sigmoid gate and ``alpha`` is chosen as
    ``sqrt(sigmoid(-w/T) * (1 + beta))`` so the combined scale stays close to
    norm-preserving. The gate parameter is initialized so that at init the
    branch weight ``beta`` equals ``init_scale``.
    """

    def __init__(self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3):
        super().__init__()
        self.eps = eps
        self.temperature = temperature
        # Clamp init_scale into (eps, 1 - eps) so the logit is finite, then
        # invert the tempered sigmoid to get the initial gate value.
        clipped = min(1.0 - eps, max(eps, init_scale))
        logit = temperature * math.log(clipped / (1.0 - clipped))
        self.gate = nn.Parameter(torch.full((dim,), logit))

    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-channel gate across all leading dimensions.
        view_shape = [1] * (stream.ndim - 1) + [-1]
        scaled = self.gate.view(view_shape) / self.temperature
        beta = torch.sigmoid(scaled)
        alpha_sq = torch.sigmoid(-scaled) * (1.0 + beta)
        # Floor alpha^2 at eps so sqrt stays well-conditioned for large gates.
        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
        return alpha * stream + beta * branch


class ConditionalFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
Expand Down Expand Up @@ -99,7 +117,8 @@

class TransformerBlock(nn.Module):
def __init__(
self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
self, args: ModelArgs, attention: Attention, mlp_type: str = "default",
layer_id: int = 0,
):
"""
Transformer block with support for pre-norm and post-norm.
Expand All @@ -110,6 +129,7 @@
the attention type is registered in the ATTENTION_REGISTRY.
mlp_type (str): MLP type for this layer. "default" for standard
FFN, "skip" for no FFN block.
layer_id (int): layer index, used for residual gate initialization.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
Expand All @@ -118,6 +138,7 @@
self.head_dim = args.head_dim
self.attention = attention
self.mlp_type = mlp_type.lower()
self.use_residual_gate = args.use_residual_gate

assert (
args.hidden_dim is not None
Expand Down Expand Up @@ -150,6 +171,16 @@
add_unit_offset=args.rms_norm_add_unit_offset,
)

if args.use_residual_gate:
attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
ffn_init = 1.0 / (2 * layer_id + 2)
self.add_attn = NormPreservingResidualConnection(dim=args.dim, init_scale=attn_init)
self.add_ffn = NormPreservingResidualConnection(dim=args.dim, init_scale=ffn_init)
self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)

if args.use_ffn_learnable_scales and self.mlp_type != "skip":
self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)

@classmethod
def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
"""
Expand All @@ -169,21 +200,38 @@
mlp_type = args.mlp_type[layer_id]
cls = ATTENTION_REGISTRY[args.attention_type]
attention = cls(args, layer_id, rope, **args.attention_kwargs)
return TransformerBlock(args, attention, mlp_type=mlp_type)
return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)

def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
h, attn_options_update = self.attention(
self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
)
if not isinstance(self.attention, AttentionSkip):
h = x + h
if self.use_residual_gate:
if hasattr(self, "post_attn_norm"):
h = self.post_attn_norm(h)
h = self.add_attn(stream=x, branch=h)
else:
h = x + h

if self.mlp_type == "skip":
out = h
elif hasattr(self, "block_sparse_moe"):
out = h + self.block_sparse_moe(self.ffn_norm(h))
ffn_out = self.block_sparse_moe(self.ffn_norm(h))
if hasattr(self, "post_ffn_norm"):
ffn_out = self.post_ffn_norm(ffn_out)
if self.use_residual_gate:
out = self.add_ffn(stream=h, branch=ffn_out)
else:
out = h + ffn_out
else:
out = h + self.feed_forward(self.ffn_norm(h))
ffn_out = self.feed_forward(self.ffn_norm(h))
if hasattr(self, "post_ffn_norm"):
ffn_out = self.post_ffn_norm(ffn_out)
if self.use_residual_gate:
out = self.add_ffn(stream=h, branch=ffn_out)
else:
out = h + ffn_out
return out, attn_options_update


Expand Down Expand Up @@ -371,7 +419,7 @@
and model_args.layer_types[layer_id] == "skip_attention"
):
attention = AttentionSkip()
transformer_block = TransformerBlock(model_args, attention)
transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
layers.append(transformer_block)
elif (
model_args.layer_types
Expand All @@ -386,13 +434,13 @@
attention = linear_cls(
model_args, layer_id, rope, **model_args.attention_kwargs
)
transformer_block = TransformerBlock(model_args, attention)
transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
layers.append(transformer_block)
else:
attention = cls(
model_args, layer_id, rope, **model_args.attention_kwargs
) # pyre-ignore[45]
transformer_block = TransformerBlock(model_args, attention)
transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
layers.append(transformer_block)

return Transformer(model_args, layers, rope)
10 changes: 10 additions & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import dataclasses
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -102,6 +102,7 @@
apply_output: bool = True # Use output layer (unembedding) inside the transformer
use_qk_norm: bool = False # apply normalization to q and k in the attention
qk_norm_before_rope: bool = False # when to apply qk norm
qk_norm_affine: bool = True # whether QK norm has learnable weight (False = scaleless)
residual_multiplier: Optional[float] = (
None # Scaling factor applied to the residual hidden states
)
Expand Down Expand Up @@ -162,6 +163,15 @@
final_logit_softcapping: Optional[float] = None
attn_logit_softcapping: Optional[float] = None

# rlformers forward-pass features for on-device model parity
normalize_tok_embeddings: bool = False
scale_query_by: float = 1.0
use_attn_o_gate: bool = False
use_attn_o_norm: bool = False
use_residual_gate: bool = False
use_ffn_learnable_scales: bool = False
output_soft_cap_temp: Optional[float] = None

def __post_init__(self): # noqa: C901
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
Expand Down
40 changes: 21 additions & 19 deletions examples/models/llama/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,35 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

Args:
x (torch.Tensor): The input tensor.
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
return output * (1.0 + self.weight.float()).type_as(x)
return output * self.weight.type_as(x)

Returns:
torch.Tensor: The normalized tensor.

"""
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
class ScalelessRMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps

def forward(self, x):
"""
Forward pass through the RMSNorm layer.
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.
class RMSNormWithInputScale(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.dim = dim
self.weight = torch.nn.Parameter(torch.ones(dim))

"""
output = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
return output * (1.0 + self.weight.float()).type_as(x)
return output * self.weight.type_as(x)
def forward(self, x):
scaled = self.weight * x
return scaled * torch.rsqrt((scaled * scaled).mean(-1, keepdim=True) + self.eps)


class RMSNormGated(nn.Module):
Expand Down
Loading