Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
305 changes: 120 additions & 185 deletions tests/unit/model_bridge/supported_architectures/test_llama_adapter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
"""Unit tests for LlamaArchitectureAdapter.

Tests cover:
- Config flags set by the adapter
- Component mapping structure (bridge types and HF module names)
- Weight conversion key set and rearrange patterns
- GQA: n_key_value_heads propagates to K/V conversions only
- setup_component_testing rotary embedding wiring
- Weight conversion key set
- GQA support via n_key_value_heads
"""

from types import SimpleNamespace

import pytest

from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
ParamProcessingConversion,
)
from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
Expand All @@ -34,42 +28,119 @@
# Helpers / fixtures
# ---------------------------------------------------------------------------

N_HEADS = 8
N_KV_HEADS = 2
D_MODEL = 64
D_MLP = 256
N_LAYERS = 2
N_CTX = 256
D_VOCAB = 1000


def _make_cfg(
    n_heads: int = N_HEADS,
    n_kv_heads: int | None = None,
    d_model: int = D_MODEL,
    n_layers: int = N_LAYERS,
    d_mlp: int = D_MLP,
    d_vocab: int = D_VOCAB,
    n_ctx: int = N_CTX,
    **overrides,
) -> TransformerBridgeConfig:
    """Return a minimal TransformerBridgeConfig for Llama adapter tests.

    Args:
        n_heads: Number of attention heads; ``d_head`` is derived as
            ``d_model // n_heads``.
        n_kv_heads: When not ``None``, passed to the config constructor as
            ``n_key_value_heads``. When ``None`` the key is not passed at all.
        d_model: Residual stream width.
        n_layers: Number of blocks.
        d_mlp: MLP hidden width.
        d_vocab: Vocabulary size.
        n_ctx: Context length.
        **overrides: Extra attributes assigned onto the finished config with
            ``setattr`` (for example ``n_key_value_heads=8``). Applied last, so
            they win over the named parameters.

    Returns:
        The constructed config with ``architecture="LlamaForCausalLM"``.
    """
    kwargs = {
        "d_model": d_model,
        "d_head": d_model // n_heads,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "n_ctx": n_ctx,
        "d_vocab": d_vocab,
        "d_mlp": d_mlp,
        "default_prepend_bos": False,
        "architecture": "LlamaForCausalLM",
    }
    # Only forward the GQA setting when it was requested, so the non-GQA
    # config is built without the key.
    if n_kv_heads is not None:
        kwargs["n_key_value_heads"] = n_kv_heads

    cfg = TransformerBridgeConfig(**kwargs)

    # Callers elsewhere in this module pass arbitrary attributes by keyword;
    # set them on the instance rather than assuming the constructor takes them.
    for attr_name, attr_value in overrides.items():
        setattr(cfg, attr_name, attr_value)
    return cfg


@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    """Provide the config produced by ``_make_cfg()`` with all defaults."""
    default_cfg = _make_cfg()
    return default_cfg

@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> LlamaArchitectureAdapter:
    """Build a LlamaArchitectureAdapter from the ``cfg`` fixture."""
    return LlamaArchitectureAdapter(cfg)


@pytest.fixture
def gqa_cfg() -> TransformerBridgeConfig:
    """Provide a config built with ``n_kv_heads=N_KV_HEADS``."""
    grouped_cfg = _make_cfg(n_kv_heads=N_KV_HEADS)
    return grouped_cfg


@pytest.fixture
def gqa_adapter(gqa_cfg: TransformerBridgeConfig) -> LlamaArchitectureAdapter:
    """Build a LlamaArchitectureAdapter from the ``gqa_cfg`` fixture."""
    return LlamaArchitectureAdapter(gqa_cfg)


# ---------------------------------------------------------------------------
# Component mapping — top-level key set and bridge types
# Config flag tests
# ---------------------------------------------------------------------------


class TestLlamaComponentMapping:
"""Component mapping has the correct slots, bridge types, and HF module paths."""
class TestLlamaAdapterConfig:
    """Tests that the adapter sets the correct config flags."""

    def test_normalization_type(self, adapter: LlamaArchitectureAdapter) -> None:
        """``normalization_type`` on the adapter config is ``"RMS"``."""
        assert adapter.cfg.normalization_type == "RMS"

    def test_positional_embedding_type(self, adapter: LlamaArchitectureAdapter) -> None:
        """``positional_embedding_type`` on the adapter config is ``"rotary"``."""
        assert adapter.cfg.positional_embedding_type == "rotary"

    def test_final_rms(self, adapter: LlamaArchitectureAdapter) -> None:
        """Llama uses RMSNorm as the final norm (final_rms=True)."""
        assert adapter.cfg.final_rms is True

    def test_gated_mlp(self, adapter: LlamaArchitectureAdapter) -> None:
        """``gated_mlp`` on the adapter config is ``True``."""
        assert adapter.cfg.gated_mlp is True

    def test_attn_only_false(self, adapter: LlamaArchitectureAdapter) -> None:
        """``attn_only`` on the adapter config is ``False``."""
        assert adapter.cfg.attn_only is False

    def test_uses_rms_norm(self, adapter: LlamaArchitectureAdapter) -> None:
        """``uses_rms_norm`` on the adapter config is ``True``."""
        assert adapter.cfg.uses_rms_norm is True


# ---------------------------------------------------------------------------
# GQA config tests
# ---------------------------------------------------------------------------


class TestLlamaAdapterGQA:
    """Tests for GQA (Grouped Query Attention) via n_key_value_heads."""

    def test_n_key_value_heads_propagated_to_cfg(
        self, gqa_adapter: LlamaArchitectureAdapter
    ) -> None:
        """The adapter's config carries the requested ``n_key_value_heads``."""
        assert gqa_adapter.cfg.n_key_value_heads == N_KV_HEADS

    def test_n_key_value_heads_in_default_config(
        self, gqa_adapter: LlamaArchitectureAdapter
    ) -> None:
        """``default_config`` exposes the requested ``n_key_value_heads``."""
        assert gqa_adapter.default_config["n_key_value_heads"] == N_KV_HEADS

    def test_no_n_key_value_heads_without_gqa(self, adapter: LlamaArchitectureAdapter) -> None:
        """When n_key_value_heads is not set, default_config should not include it."""
        assert "n_key_value_heads" not in adapter.default_config


# ---------------------------------------------------------------------------
# Component mapping tests
# ---------------------------------------------------------------------------


class TestLlamaAdapterComponentMapping:
"""Tests that component_mapping has the correct bridge types and HF module names."""

def test_top_level_keys(self, adapter: LlamaArchitectureAdapter) -> None:
assert set(adapter.component_mapping.keys()) == {
Expand All @@ -80,10 +151,6 @@ def test_top_level_keys(self, adapter: LlamaArchitectureAdapter) -> None:
"unembed",
}

def test_no_pos_embed_key(self, adapter: LlamaArchitectureAdapter) -> None:
    """LLaMA uses rotary embeddings — no learned positional embedding component."""
    mapping = adapter.component_mapping
    assert "pos_embed" not in mapping

def test_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert isinstance(mapping["embed"], EmbeddingBridge)
Expand All @@ -104,75 +171,60 @@ def test_block_submodule_keys(self, adapter: LlamaArchitectureAdapter) -> None:
blocks = adapter.component_mapping["blocks"]
assert set(blocks.submodules.keys()) == {"ln1", "ln2", "attn", "mlp"}

def test_block_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None:
    """Each block submodule slot is an instance of its expected bridge class."""
    submodules = adapter.component_mapping["blocks"].submodules
    assert isinstance(submodules["ln1"], RMSNormalizationBridge)
    assert isinstance(submodules["ln2"], RMSNormalizationBridge)
    assert isinstance(submodules["attn"], PositionEmbeddingsAttentionBridge)
    assert isinstance(submodules["mlp"], GatedMLPBridge)

def test_block_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None:
    """Each block submodule slot carries the expected HF module name."""
    submodules = adapter.component_mapping["blocks"].submodules
    assert submodules["ln1"].name == "input_layernorm"
    assert submodules["ln2"].name == "post_attention_layernorm"
    assert submodules["attn"].name == "self_attn"
    assert submodules["mlp"].name == "mlp"

def test_attn_requires_mask_and_position_embeddings(
    self, adapter: LlamaArchitectureAdapter
) -> None:
    """LLaMA RoPE attention requires both attention mask and position embeddings."""
    attn = adapter.component_mapping["blocks"].submodules["attn"]
    assert attn.requires_attention_mask is True
    assert attn.requires_position_embeddings is True

def test_attention_submodule_keys(self, adapter: LlamaArchitectureAdapter) -> None:
    """Llama uses separate Q, K, V, O projections."""
    attn = adapter.component_mapping["blocks"].submodules["attn"]
    assert set(attn.submodules.keys()) == {"q", "k", "v", "o"}

def test_attention_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None:
    """The q/k/v/o attention slots carry the expected HF projection names."""
    projections = adapter.component_mapping["blocks"].submodules["attn"].submodules
    # Check the slot set first so a missing slot fails here, not as a KeyError below.
    assert set(projections.keys()) == {"q", "k", "v", "o"}
    assert projections["q"].name == "q_proj"
    assert projections["k"].name == "k_proj"
    assert projections["v"].name == "v_proj"
    assert projections["o"].name == "o_proj"

def test_attn_qkvo_are_linear_bridges(self, adapter: LlamaArchitectureAdapter) -> None:
    """Every attention submodule is a LinearBridge."""
    attn = adapter.component_mapping["blocks"].submodules["attn"]
    for slot, sub in attn.submodules.items():
        assert isinstance(sub, LinearBridge), f"attn.{slot} is {type(sub).__name__}"

def test_mlp_submodule_keys(self, adapter: LlamaArchitectureAdapter) -> None:
    """The MLP bridge exposes exactly the gate/in/out slots."""
    mlp = adapter.component_mapping["blocks"].submodules["mlp"]
    assert set(mlp.submodules.keys()) == {"gate", "in", "out"}

def test_mlp_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None:
    """The gate/in/out MLP slots carry the expected HF projection names."""
    projections = adapter.component_mapping["blocks"].submodules["mlp"].submodules
    assert projections["gate"].name == "gate_proj"
    assert projections["in"].name == "up_proj"
    assert projections["out"].name == "down_proj"

def test_attention_linear_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None:
    """Every attention submodule is a LinearBridge."""
    attn = adapter.component_mapping["blocks"].submodules["attn"]
    for slot, submodule in attn.submodules.items():
        assert isinstance(submodule, LinearBridge), (
            f"attn.{slot} is {type(submodule).__name__}"
        )

# ---------------------------------------------------------------------------
# Anti-drift config flags
# ---------------------------------------------------------------------------


class TestLlamaAdapterConfig:
    """Anti-drift flags that must not silently regress."""

    def test_final_rms_is_true(self, adapter: LlamaArchitectureAdapter) -> None:
        """``final_rms`` on the adapter config is ``True``."""
        assert adapter.cfg.final_rms is True

    def test_uses_rms_norm_is_true(self, adapter: LlamaArchitectureAdapter) -> None:
        """``uses_rms_norm`` on the adapter config is ``True``."""
        assert adapter.cfg.uses_rms_norm is True

    def test_gated_mlp_is_true(self, adapter: LlamaArchitectureAdapter) -> None:
        """LLaMA uses a gated SwiGLU MLP — must not silently revert to vanilla MLP."""
        assert adapter.cfg.gated_mlp is True
def test_mlp_linear_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None:
    """Every MLP submodule is a LinearBridge."""
    mlp = adapter.component_mapping["blocks"].submodules["mlp"]
    for slot, submodule in mlp.submodules.items():
        assert isinstance(submodule, LinearBridge), (
            f"mlp.{slot} is {type(submodule).__name__}"
        )


# ---------------------------------------------------------------------------
# Weight processing conversions
# Weight conversion key tests
# ---------------------------------------------------------------------------


class TestLlamaWeightConversions:
"""weight_processing_conversions has exactly the expected QKVO keys."""
class TestLlamaAdapterWeightConversions:
"""Tests that weight_processing_conversions has exactly the expected keys."""

def test_exact_conversion_key_set(self, adapter: LlamaArchitectureAdapter) -> None:
assert set(adapter.weight_processing_conversions.keys()) == {
Expand All @@ -181,120 +233,3 @@ def test_exact_conversion_key_set(self, adapter: LlamaArchitectureAdapter) -> No
"blocks.{i}.attn.v.weight",
"blocks.{i}.attn.o.weight",
}

def test_qkv_conversions_use_split_heads_pattern(
    self, adapter: LlamaArchitectureAdapter
) -> None:
    """'(n h) m -> n m h' splits [n_heads*d_head, d_model] → [n, d_model, d_head]."""
    conversions = adapter.weight_processing_conversions
    for slot in ("q", "k", "v"):
        conv = conversions[f"blocks.{{i}}.attn.{slot}.weight"]
        assert isinstance(conv, ParamProcessingConversion)
        assert isinstance(conv.tensor_conversion, RearrangeTensorConversion)
        assert conv.tensor_conversion.pattern == "(n h) m -> n m h"

def test_o_conversion_uses_merge_heads_pattern(self, adapter: LlamaArchitectureAdapter) -> None:
    """'m (n h) -> n h m' moves n to the front for the output projection."""
    conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"]
    assert isinstance(conv, ParamProcessingConversion)
    rearrange = conv.tensor_conversion
    assert isinstance(rearrange, RearrangeTensorConversion)
    assert rearrange.pattern == "m (n h) -> n h m"

def test_no_bias_conversion_keys(self, adapter: LlamaArchitectureAdapter) -> None:
    """LLaMA has no attention or MLP biases — no b_Q/b_K/b_V/b_O conversions."""
    bias_like = [
        key
        for key in adapter.weight_processing_conversions
        if "bias" in key or ".b_" in key
    ]
    assert not bias_like

def test_no_norm_conversion_keys(self, adapter: LlamaArchitectureAdapter) -> None:
    """RMSNorm has no bias offset — no ln1/ln2/ln_final conversion entries."""
    norm_like = [key for key in adapter.weight_processing_conversions if "ln" in key]
    assert not norm_like


# ---------------------------------------------------------------------------
# GQA support — LLaMA 3.1 / 3.2 / 3.3
# ---------------------------------------------------------------------------


class TestLlamaGQASupport:
    """n_key_value_heads must propagate to K/V conversions and leave Q/O unchanged."""

    @staticmethod
    def _heads_axis(adapter: LlamaArchitectureAdapter, slot: str) -> int:
        """Return the ``n`` axis length of the conversion for attention ``slot``."""
        conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"]
        return conv.tensor_conversion.axes_lengths["n"]

    def test_no_gqa_defaults_to_n_heads(self) -> None:
        """Without n_key_value_heads, K/V use n_heads (MHA mode)."""
        adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32))
        assert self._heads_axis(adapter, "k") == 32

    def test_gqa_propagates_to_kv_conversions(self) -> None:
        """With 8 KV heads (LLaMA-3 style), K/V conversions must use n=8."""
        adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8))
        for slot in ("k", "v"):
            assert self._heads_axis(adapter, slot) == 8

    def test_gqa_does_not_affect_q_conversion(self) -> None:
        """Q always uses full n_heads regardless of GQA."""
        adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8))
        assert self._heads_axis(adapter, "q") == 32

    def test_gqa_does_not_affect_o_conversion(self) -> None:
        """O projection always uses n_heads; GQA only affects K/V."""
        adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8))
        assert self._heads_axis(adapter, "o") == 32


# ---------------------------------------------------------------------------
# setup_component_testing — rotary embedding wiring
# ---------------------------------------------------------------------------


class _DummyAttn:
def __init__(self) -> None:
self.rotary_emb = None

def set_rotary_emb(self, rotary_emb: object) -> None:
self.rotary_emb = rotary_emb


class _DummyBlock:
def __init__(self, has_attn: bool = True) -> None:
if has_attn:
self.attn = _DummyAttn()


class _DummyBridgeModel:
def __init__(self, blocks: list) -> None:
self.blocks = blocks


def _fake_hf_model(rotary_emb: object) -> SimpleNamespace:
return SimpleNamespace(model=SimpleNamespace(rotary_emb=rotary_emb))


class TestLlamaSetupComponentTesting:
    """setup_component_testing wires rotary_emb onto every block's attention bridge."""

    def test_sets_rotary_emb_on_all_blocks(self) -> None:
        """Every block with an ``attn`` receives the HF model's rotary embedding."""
        adapter = LlamaArchitectureAdapter(_make_cfg())
        rotary_emb = object()
        bridge_model = _DummyBridgeModel([_DummyBlock(), _DummyBlock(), _DummyBlock()])

        adapter.setup_component_testing(_fake_hf_model(rotary_emb), bridge_model=bridge_model)

        for block in bridge_model.blocks:
            assert block.attn.rotary_emb is rotary_emb

    def test_skips_blocks_without_attn(self) -> None:
        """A block lacking ``attn`` is tolerated and left without one."""
        adapter = LlamaArchitectureAdapter(_make_cfg())
        rotary_emb = object()
        with_attn = _DummyBlock()
        without_attn = _DummyBlock(has_attn=False)
        bridge_model = _DummyBridgeModel([with_attn, without_attn])

        adapter.setup_component_testing(_fake_hf_model(rotary_emb), bridge_model=bridge_model)

        assert with_attn.attn.rotary_emb is rotary_emb
        # The skipped block must not have gained an attn attribute as a side effect.
        assert not hasattr(without_attn, "attn")

    def test_no_bridge_model_does_not_raise(self) -> None:
        """setup_component_testing without a bridge_model must not raise."""
        adapter = LlamaArchitectureAdapter(_make_cfg())
        adapter.setup_component_testing(_fake_hf_model(object()))