diff --git a/tests/unit/model_bridge/supported_architectures/test_bert_adapter.py b/tests/unit/model_bridge/supported_architectures/test_bert_adapter.py new file mode 100644 index 000000000..cdc93343d --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_bert_adapter.py @@ -0,0 +1,240 @@ +"""Unit tests for BertArchitectureAdapter. + +Tests cover: +- Component mapping structure (bridge types and HF module names) +- Learned positional embeddings (pos_embed present; no rotary_emb) +- Weight conversion key set and rearrange patterns (weights + biases) +- Post-LN architecture: supports_fold_ln must remain False +- Anti-drift config flags +""" + +import pytest + +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + BlockBridge, + EmbeddingBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.bert import ( + BertArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 12, + d_model: int = 768, + n_layers: int = 12, + d_vocab: int = 30522, + n_ctx: int = 512, + **overrides, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for BERT adapter tests.""" + cfg = TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_heads=n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + d_vocab=d_vocab, + architecture="BertForMaskedLM", + ) + for k, v in overrides.items(): + setattr(cfg, k, v) + return cfg + + +@pytest.fixture(scope="module") +def adapter() -> BertArchitectureAdapter: + return BertArchitectureAdapter(_make_cfg()) + + +# --------------------------------------------------------------------------- +# Component mapping +# --------------------------------------------------------------------------- + + +class TestBertComponentMapping: + """Component mapping has the correct slots, bridge types, and HF module paths.""" + + def test_top_level_keys(self, adapter: BertArchitectureAdapter) -> None: + assert set(adapter.component_mapping.keys()) == { + "embed", + "pos_embed", + "blocks", + "ln_final", + "unembed", + } + + def test_has_pos_embed_not_rotary(self, adapter: BertArchitectureAdapter) -> None: + """BERT uses learned positional embeddings — pos_embed present, no rotary_emb.""" + assert "pos_embed" in adapter.component_mapping + assert "rotary_emb" not in adapter.component_mapping + + def test_bridge_types(self, adapter: BertArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert isinstance(mapping["embed"], EmbeddingBridge) + assert isinstance(mapping["pos_embed"], PosEmbedBridge) + assert isinstance(mapping["blocks"], BlockBridge) + assert isinstance(mapping["ln_final"], NormalizationBridge) + assert isinstance(mapping["unembed"], UnembeddingBridge) + + def test_top_level_hf_paths(self, adapter: BertArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping["embed"].name == "bert.embeddings.word_embeddings" + assert mapping["pos_embed"].name == "bert.embeddings.position_embeddings" + assert mapping["blocks"].name == "bert.encoder.layer" + assert mapping["ln_final"].name == "cls.predictions.transform.LayerNorm" + assert mapping["unembed"].name == "cls.predictions.decoder" + + def test_block_submodule_keys(self, adapter: BertArchitectureAdapter) -> None: + assert set(adapter.component_mapping["blocks"].submodules.keys()) == { + "ln1", + "ln2", + "attn", + "mlp", + } + + def test_block_submodule_types(self, adapter: BertArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln1"], NormalizationBridge) + assert isinstance(blocks.submodules["ln2"], NormalizationBridge) + assert isinstance(blocks.submodules["attn"], AttentionBridge) + assert isinstance(blocks.submodules["mlp"], MLPBridge) + + def test_block_submodule_hf_paths(self, adapter: BertArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["ln1"].name == "attention.output.LayerNorm" + assert blocks.submodules["ln2"].name == "output.LayerNorm" + assert blocks.submodules["attn"].name == "attention" + assert blocks.submodules["mlp"].name is None + + def test_attn_submodule_keys(self, adapter: BertArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert set(attn.submodules.keys()) == {"q", "k", "v", "o"} + + def test_attn_qkvo_hf_paths(self, adapter: BertArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["q"].name == "self.query" + assert attn.submodules["k"].name == "self.key" + assert attn.submodules["v"].name == "self.value" + assert attn.submodules["o"].name == "output.dense" + + def test_attn_submodules_are_linear_bridges(self, adapter: BertArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + for sub in attn.submodules.values(): + assert isinstance(sub, LinearBridge) + + def test_mlp_submodule_hf_paths(self, adapter: BertArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["in"].name == "intermediate.dense" + assert mlp.submodules["out"].name == "output.dense" + + +# --------------------------------------------------------------------------- +# Anti-drift config flags +# --------------------------------------------------------------------------- + + +class TestBertAdapterConfig: + """Anti-drift flags that must not silently regress.""" + + def test_normalization_type_is_ln(self, adapter: BertArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "LN" + + def test_positional_embedding_type_is_standard(self, adapter: BertArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "standard" + + def test_final_rms_is_false(self, adapter: BertArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is False + + def test_gated_mlp_is_false(self, adapter: BertArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is False + + def test_attn_only_is_false(self, adapter: BertArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_supports_fold_ln_is_false(self, adapter: BertArchitectureAdapter) -> None: + """BERT uses post-LN: fold_ln assumes pre-LN and produces wrong results if enabled.""" + assert adapter.supports_fold_ln is False + + def test_supports_generation_is_false(self) -> None: + """BERT is an encoder-only model — generation is not supported.""" + assert BertArchitectureAdapter.supports_generation is False + + +# --------------------------------------------------------------------------- +# Weight processing conversions +# --------------------------------------------------------------------------- + + +class TestBertWeightConversions: + """weight_processing_conversions has exactly the expected QKV weight+bias and O weight keys.""" + + def test_exact_conversion_key_set(self, adapter: BertArchitectureAdapter) -> None: + assert set(adapter.weight_processing_conversions.keys()) == { + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.q.bias", + "blocks.{i}.attn.k.bias", + "blocks.{i}.attn.v.bias", + "blocks.{i}.attn.o.weight", + } + + def test_qkv_weight_pattern(self, adapter: BertArchitectureAdapter) -> None: + """'(h d_head) d_model -> h d_model d_head' splits heads for Q/K/V weights.""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(h d_head) d_model -> h d_model d_head" + + def test_qkv_bias_pattern(self, adapter: BertArchitectureAdapter) -> None: + """'(h d_head) -> h d_head' splits heads for Q/K/V biases.""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.bias"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(h d_head) -> h d_head" + + def test_o_weight_pattern(self, adapter: BertArchitectureAdapter) -> None: + """'d_model (h d_head) -> h d_head d_model' for output projection.""" + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "d_model (h d_head) -> h d_head d_model" + + def test_no_norm_conversion_keys(self, adapter: BertArchitectureAdapter) -> None: + """LayerNorm has a bias but no head-splitting needed — no ln conversion entries.""" + keys = set(adapter.weight_processing_conversions.keys()) + assert not any("ln" in k for k in keys) + + def test_no_o_bias_key(self, adapter: BertArchitectureAdapter) -> None: + """Output projection bias is not rearranged — only its weight is converted.""" + assert "blocks.{i}.attn.o.bias" not in adapter.weight_processing_conversions + + def test_qkv_weight_head_axis(self, adapter: BertArchitectureAdapter) -> None: + """h axis in weight conversions matches n_heads=12.""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["h"] == 12 + + def test_qkv_bias_head_axis(self, adapter: BertArchitectureAdapter) -> None: + """h axis in bias conversions matches n_heads=12.""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.bias"] + assert conv.tensor_conversion.axes_lengths["h"] == 12 diff --git a/tests/unit/model_bridge/supported_architectures/test_falcon_adapter.py b/tests/unit/model_bridge/supported_architectures/test_falcon_adapter.py new file mode 100644 index 000000000..48ac1ef21 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_falcon_adapter.py @@ -0,0 +1,343 @@ +"""Unit tests for FalconArchitectureAdapter. + +Tests cover: +- Component mapping for RoPE+parallel (default), ALiBi, and sequential variants +- Bridge types and HF module paths +- Weight conversion key set and rearrange patterns +- GQA: n_key_value_heads propagates to K/V; multi_query forces n_key_value_heads=1 +- Anti-drift config flags +""" + +import pytest + +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + ALiBiJointQKVAttentionBridge, + BlockBridge, + EmbeddingBridge, + JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + ParallelBlockBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.falcon import ( + FalconArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 32, + d_model: int = 4096, + n_layers: int = 32, + d_vocab: int = 65024, + n_ctx: int = 2048, + n_key_value_heads: int | None = 8, + alibi: bool = False, + new_decoder_architecture: bool = False, + multi_query: bool = False, + parallel_attn: bool = True, + **overrides, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for Falcon adapter tests.""" + cfg = TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_heads=n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + d_vocab=d_vocab, + n_key_value_heads=n_key_value_heads, + architecture="FalconForCausalLM", + ) + setattr(cfg, "alibi", alibi) + setattr(cfg, "new_decoder_architecture", new_decoder_architecture) + setattr(cfg, "multi_query", multi_query) + setattr(cfg, "parallel_attn", parallel_attn) + for k, v in overrides.items(): + setattr(cfg, k, v) + return cfg + + +@pytest.fixture(scope="module") +def adapter() -> FalconArchitectureAdapter: + """Default RoPE + parallel attention adapter.""" + return FalconArchitectureAdapter(_make_cfg()) + + +# --------------------------------------------------------------------------- +# Component mapping — RoPE + parallel (default) +# --------------------------------------------------------------------------- + + +class TestFalconComponentMapping: + """Component mapping has the correct slots, bridge types, and HF module paths.""" + + def test_top_level_keys_rope(self, adapter: FalconArchitectureAdapter) -> None: + assert set(adapter.component_mapping.keys()) == { + "embed", + "rotary_emb", + "blocks", + "ln_final", + "unembed", + } + + def test_no_pos_embed_key(self, adapter: FalconArchitectureAdapter) -> None: + """Falcon uses rotary or ALiBi — no learned positional embedding component.""" + assert "pos_embed" not in adapter.component_mapping + + def test_bridge_types(self, adapter: FalconArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert isinstance(mapping["embed"], EmbeddingBridge) + assert isinstance(mapping["rotary_emb"], RotaryEmbeddingBridge) + assert isinstance(mapping["ln_final"], NormalizationBridge) + assert isinstance(mapping["unembed"], UnembeddingBridge) + + def test_top_level_hf_paths(self, adapter: FalconArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping["embed"].name == "transformer.word_embeddings" + assert mapping["rotary_emb"].name == "transformer.rotary_emb" + assert mapping["blocks"].name == "transformer.h" + assert mapping["ln_final"].name == "transformer.ln_f" + assert mapping["unembed"].name == "lm_head" + + def test_blocks_is_parallel_block_bridge(self, adapter: FalconArchitectureAdapter) -> None: + """Parallel attention mode uses ParallelBlockBridge.""" + assert isinstance(adapter.component_mapping["blocks"], ParallelBlockBridge) + + def test_parallel_block_has_no_ln2(self, adapter: FalconArchitectureAdapter) -> None: + """Parallel attention shares one LN — no ln2 in default config.""" + assert "ln2" not in adapter.component_mapping["blocks"].submodules + + def test_block_submodule_keys_parallel(self, adapter: FalconArchitectureAdapter) -> None: + assert set(adapter.component_mapping["blocks"].submodules.keys()) == {"ln1", "attn", "mlp"} + + def test_block_submodule_types(self, adapter: FalconArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln1"], NormalizationBridge) + assert isinstance(blocks.submodules["attn"], JointQKVPositionEmbeddingsAttentionBridge) + assert isinstance(blocks.submodules["mlp"], MLPBridge) + + def test_ln1_hf_path_standard(self, adapter: FalconArchitectureAdapter) -> None: + """Standard (non-new-arch) uses input_layernorm.""" + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "input_layernorm" + + def test_attn_hf_path(self, adapter: FalconArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["attn"].name == "self_attention" + + def test_attn_submodule_keys(self, adapter: FalconArchitectureAdapter) -> None: + """Falcon has fused qkv projection plus separate q/k/v split-out slots and o.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert set(attn.submodules.keys()) == {"q", "k", "v", "qkv", "o"} + + def test_attn_submodule_hf_paths(self, adapter: FalconArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["qkv"].name == "query_key_value" + assert attn.submodules["o"].name == "dense" + + def test_attn_submodules_are_linear_bridges(self, adapter: FalconArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + for sub in attn.submodules.values(): + assert isinstance(sub, LinearBridge) + + def test_mlp_submodule_hf_paths(self, adapter: FalconArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["in"].name == "dense_h_to_4h" + assert mlp.submodules["out"].name == "dense_4h_to_h" + + +# --------------------------------------------------------------------------- +# ALiBi variant +# --------------------------------------------------------------------------- + + +class TestFalconALiBiVariant: + """ALiBi Falcon uses a different attention bridge and has no rotary_emb key.""" + + def test_no_rotary_emb_key(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(alibi=True)) + assert "rotary_emb" not in adapter.component_mapping + + def test_alibi_top_level_keys(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(alibi=True)) + assert set(adapter.component_mapping.keys()) == { + "embed", + "blocks", + "ln_final", + "unembed", + } + + def test_attn_is_alibi_bridge(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(alibi=True)) + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn, ALiBiJointQKVAttentionBridge) + + def test_positional_embedding_type_is_alibi(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(alibi=True)) + assert adapter.cfg.positional_embedding_type == "alibi" + + +# --------------------------------------------------------------------------- +# Sequential (non-parallel) variant +# --------------------------------------------------------------------------- + + +class TestFalconSequentialVariant: + """Non-parallel Falcon uses BlockBridge and adds ln2.""" + + def test_block_is_block_bridge_not_parallel(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(parallel_attn=False)) + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + assert not isinstance(adapter.component_mapping["blocks"], ParallelBlockBridge) + + def test_sequential_block_has_ln2(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(parallel_attn=False)) + submodules = adapter.component_mapping["blocks"].submodules + assert "ln2" in submodules + assert isinstance(submodules["ln2"], NormalizationBridge) + + def test_ln2_hf_path_sequential(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(parallel_attn=False)) + ln2 = adapter.component_mapping["blocks"].submodules["ln2"] + assert ln2.name == "post_attention_layernorm" + + def test_parallel_attn_mlp_is_false(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(parallel_attn=False)) + assert adapter.cfg.parallel_attn_mlp is False + + +# --------------------------------------------------------------------------- +# New-arch variant (Falcon 40B+) +# --------------------------------------------------------------------------- + + +class TestFalconNewArchVariant: + """New decoder architecture uses ln_attn as the first layer norm name.""" + + def test_ln1_name_is_ln_attn(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(new_decoder_architecture=True)) + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "ln_attn" + + +# --------------------------------------------------------------------------- +# Anti-drift config flags +# --------------------------------------------------------------------------- + + +class TestFalconAdapterConfig: + """Anti-drift flags that must not silently regress.""" + + def test_normalization_type_is_ln(self, adapter: FalconArchitectureAdapter) -> None: + """Falcon uses LayerNorm, not RMSNorm.""" + assert adapter.cfg.normalization_type == "LN" + + def test_gated_mlp_is_false(self, adapter: FalconArchitectureAdapter) -> None: + """Falcon uses a standard (non-gated) MLP.""" + assert adapter.cfg.gated_mlp is False + + def test_parallel_attn_mlp_is_true(self, adapter: FalconArchitectureAdapter) -> None: + """Default Falcon uses parallel attention+MLP.""" + assert adapter.cfg.parallel_attn_mlp is True + + def test_positional_embedding_type_is_rotary(self, adapter: FalconArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + +# --------------------------------------------------------------------------- +# Weight processing conversions +# --------------------------------------------------------------------------- + + +class TestFalconWeightConversions: + """weight_processing_conversions has exactly the expected QKVO keys.""" + + def test_exact_conversion_key_set(self, adapter: FalconArchitectureAdapter) -> None: + assert set(adapter.weight_processing_conversions.keys()) == { + "blocks.{i}.attn.q", + "blocks.{i}.attn.k", + "blocks.{i}.attn.v", + "blocks.{i}.attn.o", + } + + def test_qkv_conversions_use_split_heads_pattern( + self, adapter: FalconArchitectureAdapter + ) -> None: + """'(n h) m -> n m h' splits [n_heads*d_head, d_model] → [n, d_model, d_head].""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_conversion_uses_merge_heads_pattern( + self, adapter: FalconArchitectureAdapter + ) -> None: + """'m (n h) -> n h m' moves n to the front for the output projection.""" + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_q_uses_n_heads(self, adapter: FalconArchitectureAdapter) -> None: + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_kv_use_n_key_value_heads(self, adapter: FalconArchitectureAdapter) -> None: + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}"] + assert conv.tensor_conversion.axes_lengths["n"] == 8 + + def test_o_uses_n_heads(self, adapter: FalconArchitectureAdapter) -> None: + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o"] + assert o_conv.tensor_conversion.axes_lengths["n"] == 32 + + +# --------------------------------------------------------------------------- +# GQA / multi-query support +# --------------------------------------------------------------------------- + + +class TestFalconGQASupport: + """n_key_value_heads propagates to K/V; multi_query forces it to 1.""" + + def test_gqa_propagates_to_kv(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}"] + assert conv.tensor_conversion.axes_lengths["n"] == 8 + + def test_gqa_does_not_affect_q(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_gqa_does_not_affect_o(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o"] + assert o_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_no_kv_heads_falls_back_to_n_heads(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=None)) + k_conv = adapter.weight_processing_conversions["blocks.{i}.attn.k"] + assert k_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_multi_query_sets_kv_heads_to_1(self) -> None: + """multi_query=True overrides n_key_value_heads to 1 on the config.""" + adapter = FalconArchitectureAdapter(_make_cfg(n_heads=32, multi_query=True)) + assert adapter.cfg.n_key_value_heads == 1 + + def test_multi_query_kv_conversion_uses_1_head(self) -> None: + adapter = FalconArchitectureAdapter(_make_cfg(n_heads=32, multi_query=True)) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}"] + assert conv.tensor_conversion.axes_lengths["n"] == 1 diff --git a/tests/unit/model_bridge/supported_architectures/test_mistral_adapter.py b/tests/unit/model_bridge/supported_architectures/test_mistral_adapter.py new file mode 100644 index 000000000..1d423e744 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_mistral_adapter.py @@ -0,0 +1,258 @@ +"""Unit tests for MistralArchitectureAdapter. + +Tests cover: +- Component mapping structure (bridge types and HF module names) +- Weight conversion key set and rearrange patterns +- GQA: n_key_value_heads propagates to K/V conversions only +- Anti-drift config flags +""" + +import pytest + +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + PositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.mistral import ( + MistralArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 32, + d_model: int = 4096, + n_layers: int = 32, + d_vocab: int = 32000, + n_ctx: int = 4096, + n_key_value_heads: int | None = 8, + **overrides, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for Mistral adapter tests.""" + cfg = TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_heads=n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + d_vocab=d_vocab, + n_key_value_heads=n_key_value_heads, + architecture="MistralForCausalLM", + ) + for k, v in overrides.items(): + setattr(cfg, k, v) + return cfg + + +@pytest.fixture(scope="module") +def adapter() -> MistralArchitectureAdapter: + return MistralArchitectureAdapter(_make_cfg()) + + +# --------------------------------------------------------------------------- +# Component mapping — top-level key set and bridge types +# --------------------------------------------------------------------------- + + +class TestMistralComponentMapping: + """Component mapping has the correct slots, bridge types, and HF module paths.""" + + def test_top_level_keys(self, adapter: MistralArchitectureAdapter) -> None: + assert set(adapter.component_mapping.keys()) == { + "embed", + "rotary_emb", + "blocks", + "ln_final", + "unembed", + } + + def test_no_pos_embed_key(self, adapter: MistralArchitectureAdapter) -> None: + """Mistral uses rotary embeddings — no learned positional embedding component.""" + assert "pos_embed" not in adapter.component_mapping + + def test_bridge_types(self, adapter: MistralArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert isinstance(mapping["embed"], EmbeddingBridge) + assert isinstance(mapping["rotary_emb"], RotaryEmbeddingBridge) + assert isinstance(mapping["blocks"], BlockBridge) + assert isinstance(mapping["ln_final"], RMSNormalizationBridge) + assert isinstance(mapping["unembed"], UnembeddingBridge) + + def test_top_level_hf_paths(self, adapter: MistralArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping["embed"].name == "model.embed_tokens" + assert mapping["rotary_emb"].name == "model.rotary_emb" + assert mapping["blocks"].name == "model.layers" + assert mapping["ln_final"].name == "model.norm" + assert mapping["unembed"].name == "lm_head" + + def test_block_submodule_keys(self, adapter: MistralArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert set(blocks.submodules.keys()) == {"ln1", "ln2", "attn", "mlp"} + + def test_block_submodule_types(self, adapter: MistralArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) + assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) + assert isinstance(blocks.submodules["attn"], AttentionBridge) + assert isinstance(blocks.submodules["mlp"], GatedMLPBridge) + + def test_attn_is_not_position_embeddings_subclass( + self, adapter: MistralArchitectureAdapter + ) -> None: + """Mistral uses plain AttentionBridge, not PositionEmbeddingsAttentionBridge.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert not isinstance(attn, PositionEmbeddingsAttentionBridge) + + def test_block_submodule_hf_paths(self, adapter: MistralArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["ln1"].name == "input_layernorm" + assert blocks.submodules["ln2"].name == "post_attention_layernorm" + assert blocks.submodules["attn"].name == "self_attn" + assert blocks.submodules["mlp"].name == "mlp" + + def test_attn_requires_mask_and_position_embeddings( + self, adapter: MistralArchitectureAdapter + ) -> None: + """Mistral RoPE attention requires both attention mask and position embeddings.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.requires_attention_mask is True + assert attn.requires_position_embeddings is True + + def test_attn_qkvo_submodule_paths(self, adapter: MistralArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert set(attn.submodules.keys()) == {"q", "k", "v", "o"} + assert attn.submodules["q"].name == "q_proj" + assert attn.submodules["k"].name == "k_proj" + assert attn.submodules["v"].name == "v_proj" + assert attn.submodules["o"].name == "o_proj" + + def test_attn_qkvo_are_linear_bridges(self, adapter: MistralArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + for sub in attn.submodules.values(): + assert isinstance(sub, LinearBridge) + + def test_mlp_submodule_paths(self, adapter: MistralArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert set(mlp.submodules.keys()) == {"gate", "in", "out"} + assert mlp.submodules["gate"].name == "gate_proj" + assert mlp.submodules["in"].name == "up_proj" + assert mlp.submodules["out"].name == "down_proj" + + +# --------------------------------------------------------------------------- +# Anti-drift config flags +# --------------------------------------------------------------------------- + + +class TestMistralAdapterConfig: + """Anti-drift flags that must not silently regress.""" + + def test_final_rms_is_false(self, adapter: MistralArchitectureAdapter) -> None: + """Mistral does not use final RMSNorm — final_rms must remain False.""" + assert adapter.cfg.final_rms is False + + def test_uses_rms_norm_is_true(self, adapter: MistralArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_gated_mlp_is_true(self, adapter: MistralArchitectureAdapter) -> None: + """Mistral uses a gated SwiGLU MLP — must not silently revert to vanilla MLP.""" + assert adapter.cfg.gated_mlp is True + + def test_attn_only_is_false(self, adapter: MistralArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + +# --------------------------------------------------------------------------- +# Weight processing conversions +# --------------------------------------------------------------------------- + + +class TestMistralWeightConversions: + """weight_processing_conversions has exactly the expected QKVO keys.""" + + def test_exact_conversion_key_set(self, adapter: MistralArchitectureAdapter) -> None: + assert set(adapter.weight_processing_conversions.keys()) == { + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + } + + def test_qkv_conversions_use_split_heads_pattern( + self, adapter: MistralArchitectureAdapter + ) -> None: + """'(n h) m -> n m h' splits [n_heads*d_head, d_model] → [n, d_model, d_head].""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_conversion_uses_merge_heads_pattern( + self, adapter: MistralArchitectureAdapter + ) -> None: + """'m (n h) -> n h m' moves n to the front for the output projection.""" + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_no_bias_conversion_keys(self, adapter: MistralArchitectureAdapter) -> None: + """Mistral has no attention biases — no bias conversion entries.""" + keys = set(adapter.weight_processing_conversions.keys()) + assert not any("bias" in k or ".b_" in k for k in keys) + + def test_no_norm_conversion_keys(self, adapter: MistralArchitectureAdapter) -> None: + """RMSNorm has no bias offset — no ln1/ln2/ln_final conversion entries.""" + keys = set(adapter.weight_processing_conversions.keys()) + assert not any("ln" in k for k in keys) + + +# --------------------------------------------------------------------------- +# GQA support +# --------------------------------------------------------------------------- + + +class TestMistralGQASupport: + """n_key_value_heads must propagate to K/V conversions and leave Q/O unchanged.""" + + def test_no_kv_heads_falls_back_to_n_heads(self) -> None: + """Without n_key_value_heads, K/V fall back to n_heads.""" + adapter = MistralArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=None)) + k_conv = adapter.weight_processing_conversions["blocks.{i}.attn.k.weight"] + assert k_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_gqa_propagates_to_kv_conversions(self) -> None: + """With 8 KV heads, K/V conversions must use n=8.""" + adapter = MistralArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 8 + + def test_gqa_does_not_affect_q_conversion(self) -> None: + """Q always uses full n_heads regardless of GQA.""" + adapter = MistralArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_gqa_does_not_affect_o_conversion(self) -> None: + """O projection always uses n_heads; GQA only affects K/V.""" + adapter = MistralArchitectureAdapter(_make_cfg(n_heads=32, n_key_value_heads=8)) + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert o_conv.tensor_conversion.axes_lengths["n"] == 32