diff --git a/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py index a1168a591..3fe0e8fa3 100644 --- a/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py @@ -1,21 +1,15 @@ """Unit tests for LlamaArchitectureAdapter. Tests cover: +- Config flags set by the adapter - Component mapping structure (bridge types and HF module names) -- Weight conversion key set and rearrange patterns -- GQA: n_key_value_heads propagates to K/V conversions only -- setup_component_testing rotary embedding wiring +- Weight conversion key set +- GQA support via n_key_value_heads """ -from types import SimpleNamespace - import pytest -from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig -from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion -from transformer_lens.conversion_utils.param_processing_conversion import ( - ParamProcessingConversion, -) +from transformer_lens.config import TransformerBridgeConfig from transformer_lens.model_bridge.generalized_components import ( BlockBridge, EmbeddingBridge, @@ -34,42 +28,119 @@ # Helpers / fixtures # --------------------------------------------------------------------------- +N_HEADS = 8 +N_KV_HEADS = 2 +D_MODEL = 64 +D_MLP = 256 +N_LAYERS = 2 +N_CTX = 256 +D_VOCAB = 1000 + def _make_cfg( - n_heads: int = 32, - d_model: int = 4096, - n_layers: int = 32, - d_vocab: int = 32000, - n_ctx: int = 4096, - **overrides, + n_heads: int = N_HEADS, + n_kv_heads: int | None = None, + d_model: int = D_MODEL, + n_layers: int = N_LAYERS, + d_mlp: int = D_MLP, + d_vocab: int = D_VOCAB, + n_ctx: int = N_CTX, ) -> TransformerBridgeConfig: - """Return a minimal TransformerBridgeConfig for LLaMA adapter tests.""" - cfg = TransformerBridgeConfig( + """Return a minimal TransformerBridgeConfig for Llama adapter tests.""" + kwargs = dict( d_model=d_model, d_head=d_model // n_heads, - n_heads=n_heads, n_layers=n_layers, n_ctx=n_ctx, + n_heads=n_heads, d_vocab=d_vocab, + d_mlp=d_mlp, + default_prepend_bos=False, architecture="LlamaForCausalLM", ) - for k, v in overrides.items(): - setattr(cfg, k, v) - return cfg + if n_kv_heads is not None: + kwargs["n_key_value_heads"] = n_kv_heads + return TransformerBridgeConfig(**kwargs) + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() -@pytest.fixture(scope="module") -def adapter() -> LlamaArchitectureAdapter: - return LlamaArchitectureAdapter(_make_cfg()) + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> LlamaArchitectureAdapter: + return LlamaArchitectureAdapter(cfg) + + +@pytest.fixture +def gqa_cfg() -> TransformerBridgeConfig: + return _make_cfg(n_kv_heads=N_KV_HEADS) + + +@pytest.fixture +def gqa_adapter(gqa_cfg: TransformerBridgeConfig) -> LlamaArchitectureAdapter: + return LlamaArchitectureAdapter(gqa_cfg) # --------------------------------------------------------------------------- -# Component mapping — top-level key set and bridge types +# Config flag tests # --------------------------------------------------------------------------- -class TestLlamaComponentMapping: - """Component mapping has the correct slots, bridge types, and HF module paths.""" +class TestLlamaAdapterConfig: + """Tests that the adapter sets the correct config flags.""" + + def test_normalization_type(self, adapter: LlamaArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter: LlamaArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter: LlamaArchitectureAdapter) -> None: + """Llama uses RMSNorm as the final norm (final_rms=True).""" + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter: LlamaArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_attn_only_false(self, adapter: LlamaArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter: LlamaArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + +# --------------------------------------------------------------------------- +# GQA config tests +# --------------------------------------------------------------------------- + + +class TestLlamaAdapterGQA: + """Tests for GQA (Grouped Query Attention) via n_key_value_heads.""" + + def test_n_key_value_heads_propagated_to_cfg( + self, gqa_adapter: LlamaArchitectureAdapter + ) -> None: + assert gqa_adapter.cfg.n_key_value_heads == N_KV_HEADS + + def test_n_key_value_heads_in_default_config( + self, gqa_adapter: LlamaArchitectureAdapter + ) -> None: + assert gqa_adapter.default_config["n_key_value_heads"] == N_KV_HEADS + + def test_no_n_key_value_heads_without_gqa(self, adapter: LlamaArchitectureAdapter) -> None: + """When n_key_value_heads is not set, default_config should not include it.""" + assert "n_key_value_heads" not in adapter.default_config + + +# --------------------------------------------------------------------------- +# Component mapping tests +# --------------------------------------------------------------------------- + + +class TestLlamaAdapterComponentMapping: + """Tests that component_mapping has the correct bridge types and HF module names.""" def test_top_level_keys(self, adapter: LlamaArchitectureAdapter) -> None: assert set(adapter.component_mapping.keys()) == { @@ -80,10 +151,6 @@ def test_top_level_keys(self, adapter: LlamaArchitectureAdapter) -> None: "unembed", } - def test_no_pos_embed_key(self, adapter: LlamaArchitectureAdapter) -> None: - """LLaMA uses rotary embeddings — no learned positional embedding component.""" - assert "pos_embed" not in adapter.component_mapping - def test_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None: mapping = adapter.component_mapping assert isinstance(mapping["embed"], EmbeddingBridge) @@ -104,75 +171,60 @@ def test_block_submodule_keys(self, adapter: LlamaArchitectureAdapter) -> None: blocks = adapter.component_mapping["blocks"] assert set(blocks.submodules.keys()) == {"ln1", "ln2", "attn", "mlp"} - def test_block_submodule_types(self, adapter: LlamaArchitectureAdapter) -> None: + def test_block_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None: blocks = adapter.component_mapping["blocks"] assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) assert isinstance(blocks.submodules["attn"], PositionEmbeddingsAttentionBridge) assert isinstance(blocks.submodules["mlp"], GatedMLPBridge) - def test_block_submodule_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None: + def test_block_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None: blocks = adapter.component_mapping["blocks"] assert blocks.submodules["ln1"].name == "input_layernorm" assert blocks.submodules["ln2"].name == "post_attention_layernorm" assert blocks.submodules["attn"].name == "self_attn" assert blocks.submodules["mlp"].name == "mlp" - def test_attn_requires_mask_and_position_embeddings( - self, adapter: LlamaArchitectureAdapter - ) -> None: - """LLaMA RoPE attention requires both attention mask and position embeddings.""" + def test_attention_submodule_keys(self, adapter: LlamaArchitectureAdapter) -> None: + """Llama uses separate Q, K, V, O projections.""" attn = adapter.component_mapping["blocks"].submodules["attn"] - assert attn.requires_attention_mask is True - assert attn.requires_position_embeddings is True + assert set(attn.submodules.keys()) == {"q", "k", "v", "o"} - def test_attn_qkvo_submodule_paths(self, adapter: LlamaArchitectureAdapter) -> None: + def test_attention_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None: attn = adapter.component_mapping["blocks"].submodules["attn"] - assert set(attn.submodules.keys()) == {"q", "k", "v", "o"} assert attn.submodules["q"].name == "q_proj" assert attn.submodules["k"].name == "k_proj" assert attn.submodules["v"].name == "v_proj" assert attn.submodules["o"].name == "o_proj" - def test_attn_qkvo_are_linear_bridges(self, adapter: LlamaArchitectureAdapter) -> None: - attn = adapter.component_mapping["blocks"].submodules["attn"] - for sub in attn.submodules.values(): - assert isinstance(sub, LinearBridge) - - def test_mlp_submodule_paths(self, adapter: LlamaArchitectureAdapter) -> None: + def test_mlp_submodule_keys(self, adapter: LlamaArchitectureAdapter) -> None: mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert set(mlp.submodules.keys()) == {"gate", "in", "out"} + + def test_mlp_hf_paths(self, adapter: LlamaArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert mlp.submodules["gate"].name == "gate_proj" assert mlp.submodules["in"].name == "up_proj" assert mlp.submodules["out"].name == "down_proj" + def test_attention_linear_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + for submodule in attn.submodules.values(): + assert isinstance(submodule, LinearBridge) -# --------------------------------------------------------------------------- -# Anti-drift config flags -# --------------------------------------------------------------------------- - - -class TestLlamaAdapterConfig: - """Anti-drift flags that must not silently regress.""" - - def test_final_rms_is_true(self, adapter: LlamaArchitectureAdapter) -> None: - assert adapter.cfg.final_rms is True - - def test_uses_rms_norm_is_true(self, adapter: LlamaArchitectureAdapter) -> None: - assert adapter.cfg.uses_rms_norm is True - - def test_gated_mlp_is_true(self, adapter: LlamaArchitectureAdapter) -> None: - """LLaMA uses a gated SwiGLU MLP — must not silently revert to vanilla MLP.""" - assert adapter.cfg.gated_mlp is True + def test_mlp_linear_bridge_types(self, adapter: LlamaArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + for submodule in mlp.submodules.values(): + assert isinstance(submodule, LinearBridge) # --------------------------------------------------------------------------- -# Weight processing conversions +# Weight conversion key tests # --------------------------------------------------------------------------- -class TestLlamaWeightConversions: - """weight_processing_conversions has exactly the expected QKVO keys.""" +class TestLlamaAdapterWeightConversions: + """Tests that weight_processing_conversions has exactly the expected keys.""" def test_exact_conversion_key_set(self, adapter: LlamaArchitectureAdapter) -> None: assert set(adapter.weight_processing_conversions.keys()) == { @@ -181,120 +233,3 @@ def test_exact_conversion_key_set(self, adapter: LlamaArchitectureAdapter) -> No "blocks.{i}.attn.v.weight", "blocks.{i}.attn.o.weight", } - - def test_qkv_conversions_use_split_heads_pattern( - self, adapter: LlamaArchitectureAdapter - ) -> None: - """'(n h) m -> n m h' splits [n_heads*d_head, d_model] → [n, d_model, d_head].""" - for slot in ("q", "k", "v"): - conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] - assert isinstance(conv, ParamProcessingConversion) - assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) - assert conv.tensor_conversion.pattern == "(n h) m -> n m h" - - def test_o_conversion_uses_merge_heads_pattern(self, adapter: LlamaArchitectureAdapter) -> None: - """'m (n h) -> n h m' moves n to the front for the output projection.""" - conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] - assert isinstance(conv, ParamProcessingConversion) - assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) - assert conv.tensor_conversion.pattern == "m (n h) -> n h m" - - def test_no_bias_conversion_keys(self, adapter: LlamaArchitectureAdapter) -> None: - """LLaMA has no attention or MLP biases — no b_Q/b_K/b_V/b_O conversions.""" - keys = set(adapter.weight_processing_conversions.keys()) - assert not any("bias" in k or ".b_" in k for k in keys) - - def test_no_norm_conversion_keys(self, adapter: LlamaArchitectureAdapter) -> None: - """RMSNorm has no bias offset — no ln1/ln2/ln_final conversion entries.""" - keys = set(adapter.weight_processing_conversions.keys()) - assert not any("ln" in k for k in keys) - - -# --------------------------------------------------------------------------- -# GQA support — LLaMA 3.1 / 3.2 / 3.3 -# --------------------------------------------------------------------------- - - -class TestLlamaGQASupport: - """n_key_value_heads must propagate to K/V conversions and leave Q/O unchanged.""" - - def test_no_gqa_defaults_to_n_heads(self) -> None: - """Without n_key_value_heads, K/V use n_heads (MHA mode).""" - adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32)) - k_conv = adapter.weight_processing_conversions["blocks.{i}.attn.k.weight"] - assert k_conv.tensor_conversion.axes_lengths["n"] == 32 - - def test_gqa_propagates_to_kv_conversions(self) -> None: - """With 8 KV heads (LLaMA-3 style), K/V conversions must use n=8.""" - adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) - for slot in ("k", "v"): - conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] - assert conv.tensor_conversion.axes_lengths["n"] == 8 - - def test_gqa_does_not_affect_q_conversion(self) -> None: - """Q always uses full n_heads regardless of GQA.""" - adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) - q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] - assert q_conv.tensor_conversion.axes_lengths["n"] == 32 - - def test_gqa_does_not_affect_o_conversion(self) -> None: - """O projection always uses n_heads; GQA only affects K/V.""" - adapter = LlamaArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) - o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] - assert o_conv.tensor_conversion.axes_lengths["n"] == 32 - - -# --------------------------------------------------------------------------- -# setup_component_testing — rotary embedding wiring -# --------------------------------------------------------------------------- - - -class _DummyAttn: - def __init__(self) -> None: - self.rotary_emb = None - - def set_rotary_emb(self, rotary_emb: object) -> None: - self.rotary_emb = rotary_emb - - -class _DummyBlock: - def __init__(self, has_attn: bool = True) -> None: - if has_attn: - self.attn = _DummyAttn() - - -class _DummyBridgeModel: - def __init__(self, blocks: list) -> None: - self.blocks = blocks - - -def _fake_hf_model(rotary_emb: object) -> SimpleNamespace: - return SimpleNamespace(model=SimpleNamespace(rotary_emb=rotary_emb)) - - -class TestLlamaSetupComponentTesting: - """setup_component_testing wires rotary_emb onto every block's attention bridge.""" - - def test_sets_rotary_emb_on_all_blocks(self) -> None: - adapter = LlamaArchitectureAdapter(_make_cfg()) - rotary_emb = object() - bridge_model = _DummyBridgeModel([_DummyBlock(), _DummyBlock(), _DummyBlock()]) - - adapter.setup_component_testing(_fake_hf_model(rotary_emb), bridge_model=bridge_model) - - for block in bridge_model.blocks: - assert block.attn.rotary_emb is rotary_emb - - def test_skips_blocks_without_attn(self) -> None: - adapter = LlamaArchitectureAdapter(_make_cfg()) - rotary_emb = object() - bridge_model = _DummyBridgeModel([_DummyBlock(), _DummyBlock(has_attn=False)]) - - adapter.setup_component_testing(_fake_hf_model(rotary_emb), bridge_model=bridge_model) - - assert bridge_model.blocks[0].attn.rotary_emb is rotary_emb - - def test_no_bridge_model_does_not_raise(self) -> None: - """setup_component_testing without a bridge_model must not raise.""" - adapter = LlamaArchitectureAdapter(_make_cfg()) - adapter.setup_component_testing(_fake_hf_model(object()))