Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Unit tests for the Lfm2MoeArchitectureAdapter — no model downloads."""

import pytest

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.generalized_components import (
EmbeddingBridge,
RMSNormalizationBridge,
UnembeddingBridge,
)
from transformer_lens.model_bridge.supported_architectures.lfm2_moe import (
Lfm2MoeArchitectureAdapter,
Lfm2MoeBlockBridge,
)


@pytest.fixture(scope="class")
def cfg() -> TransformerBridgeConfig:
bridge_cfg = TransformerBridgeConfig(
d_model=64,
d_head=16,
n_layers=4,
n_ctx=128,
n_heads=4,
n_key_value_heads=2,
d_vocab=256,
d_mlp=224,
architecture="Lfm2MoeForCausalLM",
)
bridge_cfg.layer_types = ["conv", "conv", "full_attention", "conv"]
bridge_cfg.moe_intermediate_size = 56
bridge_cfg.num_experts = 8
bridge_cfg.experts_per_token = 2
bridge_cfg.norm_eps = 1e-5
bridge_cfg.rope_parameters = {"rope_theta": 5_000_000, "rope_type": "default"}
return bridge_cfg


@pytest.fixture(scope="class")
def adapter(cfg: TransformerBridgeConfig) -> Lfm2MoeArchitectureAdapter:
return Lfm2MoeArchitectureAdapter(cfg)


class TestLfm2MoeAdapterConfig:
def test_hybrid_config_is_propagated(self, adapter: Lfm2MoeArchitectureAdapter) -> None:
assert adapter.cfg.layer_types == ["conv", "conv", "full_attention", "conv"]
assert adapter.cfg.moe_intermediate_size == 56
assert adapter.cfg.num_experts == 8
assert adapter.cfg.experts_per_token == 2

def test_norm_and_rope_config(self, adapter: Lfm2MoeArchitectureAdapter) -> None:
assert adapter.cfg.normalization_type == "RMS"
assert adapter.cfg.positional_embedding_type == "rotary"
assert adapter.cfg.eps == 1e-5
assert adapter.cfg.rotary_base == 5_000_000

def test_default_prepend_bos_is_false(self, adapter: Lfm2MoeArchitectureAdapter) -> None:
assert adapter.cfg.default_prepend_bos is False


class TestLfm2MoeComponentMapping:
def test_has_residual_only_top_level_mapping(
self, adapter: Lfm2MoeArchitectureAdapter
) -> None:
mapping = adapter.component_mapping
assert mapping is not None
assert set(mapping) == {"embed", "blocks", "ln_final", "unembed"}

def test_component_types(self, adapter: Lfm2MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert isinstance(mapping["embed"], EmbeddingBridge)
assert isinstance(mapping["blocks"], Lfm2MoeBlockBridge)
assert isinstance(mapping["ln_final"], RMSNormalizationBridge)
assert isinstance(mapping["unembed"], UnembeddingBridge)

def test_hf_module_paths(self, adapter: Lfm2MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping["embed"].name == "model.embed_tokens"
assert mapping["blocks"].name == "model.layers"
assert mapping["ln_final"].name == "model.embedding_norm"
assert mapping["unembed"].name == "lm_head"

def test_blocks_only_advertise_supported_residual_aliases(
self, adapter: Lfm2MoeArchitectureAdapter
) -> None:
blocks = adapter.component_mapping["blocks"]
assert blocks.hook_aliases == {
"hook_resid_pre": "hook_in",
"hook_resid_post": "hook_out",
}
assert blocks.submodules == {}
123 changes: 123 additions & 0 deletions tests/unit/model_bridge/supported_architectures/test_phimoe_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Unit tests for the PhiMoEArchitectureAdapter - no model downloads."""

import pytest

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import (
RearrangeTensorConversion,
)
from transformer_lens.conversion_utils.param_processing_conversion import (
ParamProcessingConversion,
)
from transformer_lens.model_bridge.generalized_components import (
AttentionBridge,
BlockBridge,
EmbeddingBridge,
LinearBridge,
MoEBridge,
NormalizationBridge,
UnembeddingBridge,
)
from transformer_lens.model_bridge.supported_architectures.phimoe import (
PhiMoEArchitectureAdapter,
)


@pytest.fixture(scope="class")
def cfg() -> TransformerBridgeConfig:
bridge_cfg = TransformerBridgeConfig(
d_model=64,
d_head=16,
n_layers=2,
n_ctx=128,
n_heads=4,
n_key_value_heads=2,
d_vocab=256,
d_mlp=32,
architecture="PhiMoEForCausalLM",
)
bridge_cfg.num_experts = 8
bridge_cfg.experts_per_token = 2
bridge_cfg.attention_bias = True
bridge_cfg.lm_head_bias = True
bridge_cfg.router_jitter_noise = 0.01
bridge_cfg.input_jitter_noise = 0.01
bridge_cfg.rope_parameters = {"rope_theta": 10_000.0, "rope_type": "default"}
bridge_cfg.eos_token_id = 32000
return bridge_cfg


@pytest.fixture(scope="class")
def adapter(cfg: TransformerBridgeConfig) -> PhiMoEArchitectureAdapter:
return PhiMoEArchitectureAdapter(cfg)


class TestPhiMoEAdapterConfig:
def test_config_flags(self, adapter: PhiMoEArchitectureAdapter) -> None:
assert adapter.cfg.normalization_type == "LN"
assert adapter.cfg.positional_embedding_type == "rotary"
assert adapter.cfg.attn_implementation == "eager"
assert adapter.cfg.default_prepend_bos is False
assert adapter.cfg.rotary_base == 10_000.0
assert adapter.cfg.eos_token_id == [32000, 32007]

def test_moe_config_is_propagated(self, adapter: PhiMoEArchitectureAdapter) -> None:
assert adapter.cfg.num_experts == 8
assert adapter.cfg.experts_per_token == 2
assert adapter.cfg.router_jitter_noise == 0.01
assert adapter.cfg.input_jitter_noise == 0.01


class TestPhiMoEWeightConversions:
def test_conversion_keys_include_attention_biases(
self, adapter: PhiMoEArchitectureAdapter
) -> None:
assert set(adapter.weight_processing_conversions) == {
"blocks.{i}.attn.q.weight",
"blocks.{i}.attn.k.weight",
"blocks.{i}.attn.v.weight",
"blocks.{i}.attn.o.weight",
"blocks.{i}.attn.q.bias",
"blocks.{i}.attn.k.bias",
"blocks.{i}.attn.v.bias",
}

def test_kv_conversions_use_n_key_value_heads(
self, adapter: PhiMoEArchitectureAdapter
) -> None:
for key in ("blocks.{i}.attn.k.weight", "blocks.{i}.attn.v.weight"):
conv = adapter.weight_processing_conversions[key]
assert isinstance(conv, ParamProcessingConversion)
assert isinstance(conv.tensor_conversion, RearrangeTensorConversion)
assert conv.tensor_conversion.axes_lengths["n"] == 2


class TestPhiMoEComponentMapping:
def test_component_types(self, adapter: PhiMoEArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert isinstance(mapping["embed"], EmbeddingBridge)
assert isinstance(mapping["blocks"], BlockBridge)
assert isinstance(mapping["ln_final"], NormalizationBridge)
assert isinstance(mapping["unembed"], UnembeddingBridge)

def test_hf_paths(self, adapter: PhiMoEArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping["embed"].name == "model.embed_tokens"
assert mapping["blocks"].name == "model.layers"
assert mapping["ln_final"].name == "model.norm"
assert mapping["unembed"].name == "lm_head"

subs = mapping["blocks"].submodules
assert subs["ln1"].name == "input_layernorm"
assert subs["ln2"].name == "post_attention_layernorm"
assert subs["attn"].name == "self_attn"
assert subs["mlp"].name == "mlp"

def test_block_submodule_types(self, adapter: PhiMoEArchitectureAdapter) -> None:
subs = adapter.component_mapping["blocks"].submodules
assert isinstance(subs["ln1"], NormalizationBridge)
assert isinstance(subs["ln2"], NormalizationBridge)
assert isinstance(subs["attn"], AttentionBridge)
assert isinstance(subs["mlp"], MoEBridge)
assert isinstance(subs["mlp"].submodules["gate"], LinearBridge)
assert subs["mlp"].submodules["gate"].name == "router"
4 changes: 4 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
LlavaArchitectureAdapter,
LlavaNextArchitectureAdapter,
LlavaOnevisionArchitectureAdapter,
Lfm2MoeArchitectureAdapter,
Mamba2ArchitectureAdapter,
MambaArchitectureAdapter,
MingptArchitectureAdapter,
Expand All @@ -56,6 +57,7 @@
OptArchitectureAdapter,
Phi3ArchitectureAdapter,
PhiArchitectureAdapter,
PhiMoEArchitectureAdapter,
Qwen2ArchitectureAdapter,
Qwen3_5ArchitectureAdapter,
Qwen3_5MultimodalArchitectureAdapter,
Expand Down Expand Up @@ -101,6 +103,7 @@
"LlavaForConditionalGeneration": LlavaArchitectureAdapter,
"LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
"LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
"Lfm2MoeForCausalLM": Lfm2MoeArchitectureAdapter,
"Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
"MambaForCausalLM": MambaArchitectureAdapter,
"MixtralForCausalLM": MixtralArchitectureAdapter,
Expand All @@ -117,6 +120,7 @@
"OPTForCausalLM": OptArchitectureAdapter,
"PhiForCausalLM": PhiArchitectureAdapter,
"Phi3ForCausalLM": Phi3ArchitectureAdapter,
"PhiMoEForCausalLM": PhiMoEArchitectureAdapter,
"QwenForCausalLM": QwenArchitectureAdapter,
"Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
"Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
Expand Down
30 changes: 20 additions & 10 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,11 +2697,16 @@ def generate(
self.tokenizer is not None and self.tokenizer.eos_token_id is not None
)
if eos_token_id is None:
assert (
tokenizer_has_eos_token
), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
assert self.tokenizer is not None
eos_token_id = self.tokenizer.eos_token_id
# Some chat models use a turn-end token that differs from the
# tokenizer's primary EOS. Let adapters provide the full stop
# set via cfg.eos_token_id; otherwise fall back to the tokenizer.
eos_token_id = getattr(self.cfg, "eos_token_id", None)
if eos_token_id is None:
assert (
tokenizer_has_eos_token
), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
assert self.tokenizer is not None
eos_token_id = self.tokenizer.eos_token_id

if isinstance(eos_token_id, int):
stop_tokens = [eos_token_id]
Expand Down Expand Up @@ -3018,11 +3023,16 @@ def generate_stream(
self.tokenizer is not None and self.tokenizer.eos_token_id is not None
)
if eos_token_id is None:
assert (
tokenizer_has_eos_token
), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
assert self.tokenizer is not None
eos_token_id = self.tokenizer.eos_token_id
# Some chat models use a turn-end token that differs from the
# tokenizer's primary EOS. Let adapters provide the full stop
# set via cfg.eos_token_id; otherwise fall back to the tokenizer.
eos_token_id = getattr(self.cfg, "eos_token_id", None)
if eos_token_id is None:
assert (
tokenizer_has_eos_token
), "Must pass eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
assert self.tokenizer is not None
eos_token_id = self.tokenizer.eos_token_id
if isinstance(eos_token_id, int):
stop_tokens = [eos_token_id]
eos_token_for_padding = eos_token_id
Expand Down
9 changes: 9 additions & 0 deletions transformer_lens/model_bridge/sources/_bridge_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@
# Cohere
"logit_scale",
"rope_parameters",
# Hybrid/MoE architectures
"layer_types",
"moe_intermediate_size",
"norm_eps",
"attention_bias",
"lm_head_bias",
"router_jitter_noise",
"input_jitter_noise",
"eos_token_id",
]


Expand Down
11 changes: 11 additions & 0 deletions transformer_lens/model_bridge/sources/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def map_default_transformer_lens_config(hf_config):
tl_config.eps = source_config.layer_norm_eps
elif hasattr(source_config, "layer_norm_epsilon"):
tl_config.eps = source_config.layer_norm_epsilon
elif hasattr(source_config, "norm_eps"):
tl_config.eps = source_config.norm_eps
if hasattr(source_config, "num_local_experts"):
tl_config.num_experts = source_config.num_local_experts
if hasattr(source_config, "num_experts_per_tok"):
Expand Down Expand Up @@ -516,6 +518,15 @@ def boot(
# Cohere
"logit_scale",
"rope_parameters",
# Hybrid/MoE architectures
"layer_types",
"moe_intermediate_size",
"norm_eps",
"attention_bias",
"lm_head_bias",
"router_jitter_noise",
"input_jitter_noise",
"eos_token_id",
]
for attr in _HF_PASSTHROUGH_ATTRS:
val = getattr(hf_config, attr, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
from transformer_lens.model_bridge.supported_architectures.llava_onevision import (
LlavaOnevisionArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.lfm2_moe import (
Lfm2MoeArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.mamba import (
MambaArchitectureAdapter,
)
Expand Down Expand Up @@ -141,6 +144,9 @@
from transformer_lens.model_bridge.supported_architectures.phi3 import (
Phi3ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.phimoe import (
PhiMoEArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.qwen import (
QwenArchitectureAdapter,
)
Expand Down Expand Up @@ -203,6 +209,7 @@
"LlavaArchitectureAdapter",
"LlavaNextArchitectureAdapter",
"LlavaOnevisionArchitectureAdapter",
"Lfm2MoeArchitectureAdapter",
"MambaArchitectureAdapter",
"Mamba2ArchitectureAdapter",
"MingptArchitectureAdapter",
Expand All @@ -222,6 +229,7 @@
"OptArchitectureAdapter",
"PhiArchitectureAdapter",
"Phi3ArchitectureAdapter",
"PhiMoEArchitectureAdapter",
"QwenArchitectureAdapter",
"Qwen2ArchitectureAdapter",
"Qwen3ArchitectureAdapter",
Expand Down
Loading
Loading