Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions tests/integration/model_bridge/test_gemma4_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Integration tests for the Gemma 4 TransformerBridge.

Uses tiny random-init `Gemma4ForConditionalGeneration` fixtures (4 layers, d_model 8) so CI
stays light while still exercising the real per-layer heterogeneity across the family:

- ``tiny-random/gemma-4-e`` — Per-Layer Embeddings + KV-cache sharing (E2B/E4B shape)
- ``tiny-random/gemma-4-dense`` — K==V attention on global layers, no v_proj (31B shape)
- ``tiny-random/gemma-4-moe`` — router + batched experts beside the dense MLP (26B-A4B shape)

Confirms logit parity vs HF (the block bridge defers all math to HF) and that hooks fire on
the conventional single-stream residual.
"""

import pytest
import torch

from transformer_lens.model_bridge import TransformerBridge

# Variant label -> tiny random-init checkpoint id; the labels double as pytest ids.
MODEL_NAMES = dict(
    ple_kv_shared="tiny-random/gemma-4-e",
    dense_k_eq_v="tiny-random/gemma-4-dense",
    moe="tiny-random/gemma-4-moe",
)
# Fixed text-only prompt: a single sequence of five token ids, shape (1, 5).
IDS = torch.tensor([1, 2, 3, 4, 5]).unsqueeze(0)


@pytest.fixture(scope="module", params=list(MODEL_NAMES), ids=list(MODEL_NAMES))
def bridge(request):
    """Boot a float32 CPU TransformerBridge for the checkpoint selected by the param."""
    checkpoint = MODEL_NAMES[request.param]
    return TransformerBridge.boot_transformers(
        checkpoint,
        device="cpu",
        dtype=torch.float32,
    )


def test_text_only_logit_parity_vs_hf(bridge):
    """Bridge logits for the text-only prompt stay within 1e-3 of the HF model's logits."""
    from transformers import AutoModelForCausalLM

    reference_model = AutoModelForCausalLM.from_pretrained(
        bridge.cfg.model_name,
        torch_dtype=torch.float32,
        attn_implementation="eager",
    )
    reference_model = reference_model.eval()
    with torch.no_grad():
        expected = reference_model(IDS).logits
        actual = bridge.forward(IDS, return_type="logits")
    assert actual.shape == expected.shape
    # PLE / KV-sharing / K==V / MoE all run inside HF — the bridge is a pass-through.
    worst_abs_diff = (actual - expected).abs().max().item()
    assert worst_abs_diff < 1e-3


def test_config_from_text_config(bridge):
    """The bridge config reports 4 layers and is flagged as multimodal."""
    cfg = bridge.cfg
    # Text dims resolve from the nested text_config of the multimodal model.
    assert cfg.n_layers == 4
    multimodal_flag = getattr(cfg, "is_multimodal", False)
    assert multimodal_flag is True


def test_resid_hooks_fire_with_conventional_shape(bridge):
    """The residual stream is a single conventional (batch, seq, d_model) tensor."""
    seen = {}

    def record(tensor, hook):
        # Store a detached copy keyed by the firing hook's name; pass the value through.
        seen[hook.name] = tensor.detach()
        return tensor

    # Block-0 residual hooks, matched by suffix so any name prefix is accepted.
    resid_suffixes = ("blocks.0.hook_resid_pre", "blocks.0.hook_resid_post")
    resid_hooks = [key for key in bridge.hook_dict if key.endswith(resid_suffixes)]
    assert resid_hooks, "no residual hooks registered"

    with torch.no_grad():
        bridge.run_with_hooks(IDS, fwd_hooks=[(key, record) for key in resid_hooks])

    assert seen, "residual hooks did not fire"
    expected_shape = (IDS.shape[0], IDS.shape[1], bridge.cfg.d_model)
    for activation in seen.values():
        assert activation.shape == expected_shape


def test_run_with_cache_text_only(bridge):
    """run_with_cache on the text-only prompt yields finite logits and a non-empty cache."""
    with torch.no_grad():
        result = bridge.run_with_cache(IDS)
    logits, cache = result
    assert torch.isfinite(logits).all()
    assert len(cache) > 0
166 changes: 166 additions & 0 deletions tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Unit tests for the Gemma 4 architecture adapter."""

from types import SimpleNamespace

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.generalized_components import (
DelegatedAttentionBlockBridge,
EmbeddingBridge,
LinearBridge,
RotaryEmbeddingBridge,
UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
GeneralizedComponent,
)
from transformer_lens.model_bridge.supported_architectures.gemma4 import (
Gemma4ArchitectureAdapter,
)

# Architecture names passed to the config helper: the default variant and the unified one.
ARCH, ARCH_UNIFIED = (
    "Gemma4ForConditionalGeneration",
    "Gemma4UnifiedForConditionalGeneration",
)


def _cfg(arch: str = ARCH, **kwargs) -> TransformerBridgeConfig:
    """Build a bridge config for ``arch`` with fixed text dims and a stub vision config.

    Extra keyword arguments are forwarded to ``TransformerBridgeConfig``.
    """
    # Both variants are multimodal (have vision_config + embed_vision).
    vision_stub = SimpleNamespace(
        hidden_size=2048,
        num_hidden_layers=27,
        num_attention_heads=16,
    )
    config = TransformerBridgeConfig(
        d_model=1536,
        d_head=256,
        n_heads=8,
        n_layers=35,
        n_ctx=131072,
        d_vocab=262144,
        n_key_value_heads=1,
        architecture=arch,
        **kwargs,
    )
    config.vision_config = vision_stub
    config.vision_soft_tokens_per_image = 256
    return config


def _adapter(arch: str = ARCH) -> Gemma4ArchitectureAdapter:
    """Construct a Gemma 4 adapter from the shared test config for ``arch``."""
    config = _cfg(arch)
    return Gemma4ArchitectureAdapter(config)


def test_config_flags():
    """The adapter sets the expected config flags and processing capabilities."""
    adapter = _adapter()
    cfg = adapter.cfg
    # Multimodal (Gemma4ForConditionalGeneration has vision tower + projector).
    assert cfg.is_multimodal is True
    # PLE / layer_scalar / MoE residual topology is not fold-safe.
    assert adapter.supports_fold_ln is False
    assert adapter.weight_processing_conversions == {}
    assert cfg.normalization_type == "RMS"
    # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3.
    assert cfg.rmsnorm_uses_offset is False
    assert cfg.positional_embedding_type == "rotary"
    assert adapter.applicable_phases == [1, 2, 4]


def test_config_flags_unified():
    """Gemma4UnifiedForConditionalGeneration (12B) is encoder-free but still multimodal:
    has model.embed_vision (raw-patch projector) but no model.vision_tower."""
    adapter = _adapter(ARCH_UNIFIED)
    assert adapter.cfg.is_multimodal is True
    mapping = adapter.component_mapping
    assert "vision_encoder" not in mapping
    assert "vision_projector" in mapping
    projector = mapping["vision_projector"]
    assert projector.name == "model.embed_vision"


def test_text_path_nested_under_language_model():
    """Text components map to paths under model.language_model (lm_head at top level)."""
    mapping = _adapter().component_mapping

    expected_paths = {
        "embed": "model.language_model.embed_tokens",
        "rotary_emb": "model.language_model.rotary_emb",
        "blocks": "model.language_model.layers",
        "ln_final": "model.language_model.norm",
        "unembed": "lm_head",
    }
    for key, path in expected_paths.items():
        assert mapping[key].name == path

    expected_types = {
        "embed": EmbeddingBridge,
        "rotary_emb": RotaryEmbeddingBridge,
        "blocks": DelegatedAttentionBlockBridge,
        "unembed": UnembeddingBridge,
    }
    for key, bridge_cls in expected_types.items():
        assert isinstance(mapping[key], bridge_cls)


def test_vision_components_present_for_multimodal():
    """Gemma4ForConditionalGeneration has vision_tower + embed_vision."""
    mapping = _adapter().component_mapping
    for key in ("vision_encoder", "vision_projector"):
        assert key in mapping
    encoder = mapping["vision_encoder"]
    projector = mapping["vision_projector"]
    assert encoder.name == "model.vision_tower"
    assert projector.name == "model.embed_vision"
    assert isinstance(projector, GeneralizedComponent)

    # Vision config fields extracted from vision_config.
    cfg = _adapter().cfg
    assert cfg.vision_hidden_size == 2048
    assert cfg.vision_num_layers == 27
    assert cfg.vision_num_heads == 16
    assert cfg.mm_tokens_per_image == 256


def test_block_decomposition():
    """Blocks expose attn/mlp plus the four sandwich norms, and the norms are required."""
    submodules = _adapter().component_mapping["blocks"].submodules
    core_parts = ("attn", "mlp")
    for part in core_parts:
        assert part in submodules
    # Sandwich norms (same shape as Gemma 2/3) under canonical keys.
    sandwich_norms = ("ln1", "ln1_post", "ln2", "ln2_post")
    for norm_key in sandwich_norms:
        assert norm_key in submodules
        assert submodules[norm_key].optional is False


def test_split_qkv_fork_aliases_absent():
    """Attention is delegated wholesale to HF; per-layer structure is heterogeneous
    (KV-shared layers have no k/v projections), so the split-qkv fork aliases
    do not apply."""
    aliases = _adapter().component_mapping["blocks"].hook_aliases
    fork_aliases = ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in")
    present_forks = [alias for alias in fork_aliases if alias in aliases]
    assert not present_forks

    # The single-stream residual aliases remain, redirected through the sandwich norms.
    expected_targets = {
        "hook_resid_mid": "ln2.hook_in",
        "hook_attn_out": "ln1_post.hook_out",
        "hook_mlp_out": "ln2_post.hook_out",
    }
    for alias, target in expected_targets.items():
        assert aliases[alias] == target


def test_kv_shared_and_k_eq_v_submodules_are_optional():
    """KV-shared layers (E2B/E4B) drop k/v proj + norms; K==V global-attention
    layers (31B / 26B-A4B) drop v_proj."""
    attn_parts = _adapter().component_mapping["blocks"].submodules["attn"].submodules
    always_present = ("q", "o", "q_norm")
    for key in always_present:
        assert attn_parts[key].optional is False
    maybe_absent = ("k", "v", "k_norm", "v_norm")
    for key in maybe_absent:
        assert attn_parts[key].optional is True
    assert isinstance(attn_parts["q"], LinearBridge)


def test_per_layer_embedding_submodules_are_optional():
    """PLE modules exist only when hidden_size_per_layer_input > 0 (E2B/E4B)."""
    submodules = _adapter().component_mapping["blocks"].submodules
    ple_keys = (
        "per_layer_input_gate",
        "per_layer_projection",
        "post_per_layer_input_norm",
    )
    for key in ple_keys:
        component = submodules[key]
        assert component.optional is True


def test_moe_submodules_are_optional():
    """MoE branch exists only when enable_moe_block (26B-A4B)."""
    submodules = _adapter().component_mapping["blocks"].submodules
    moe_keys = (
        "router",
        "experts",
        "pre_feedforward_layernorm_2",
        "post_feedforward_layernorm_1",
        "post_feedforward_layernorm_2",
    )
    for key in moe_keys:
        component = submodules[key]
        assert component.optional is True


def test_gated_mlp_decomposition():
    """The MLP's gate/in/out submodules point at gate_proj/up_proj/down_proj."""
    mlp_parts = _adapter().component_mapping["blocks"].submodules["mlp"].submodules
    expected_names = {
        "gate": "gate_proj",
        "in": "up_proj",
        "out": "down_proj",
    }
    for key, module_name in expected_names.items():
        assert mlp_parts[key].name == module_name
3 changes: 2 additions & 1 deletion transformer_lens/benchmarks/main_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,8 @@ def cleanup_model(model, model_name_str: str):
model_name, trust_remote_code=trust_remote_code, token=_hf_token()
)
if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None)
eos = getattr(hf_config, "eos_token_id", None)
hf_config.pad_token_id = eos[0] if isinstance(eos, (list, tuple)) else eos
hf_kwargs["config"] = hf_config
if trust_remote_code:
hf_kwargs["trust_remote_code"] = True
Expand Down
13 changes: 7 additions & 6 deletions transformer_lens/benchmarks/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
through forward(), generate(), and run_with_cache().
"""


import torch

from transformer_lens.benchmarks.utils import (
Expand Down Expand Up @@ -46,8 +45,10 @@ def _prepare_test_inputs(bridge: TransformerBridge):
# Different models use different tokens:
# LLava: image_token = "<image>"
# Gemma3: boi_token = "<start_of_image>"
image_token = getattr(bridge.processor, "boi_token", None) or getattr(
bridge.processor, "image_token", "<image>"
# Gemma4: image_token is the expandable placeholder (280 tokens),
# boi_token ("<|image>") is just a marker — use image_token first.
image_token = getattr(bridge.processor, "image_token", None) or getattr(
bridge.processor, "boi_token", "<image>"
)
prompt = f"{image_token}\nDescribe this image."
try:
Expand Down Expand Up @@ -141,9 +142,9 @@ def benchmark_multimodal_forward(
details={
"logits_shape": list(logits.shape),
"input_ids_shape": list(input_ids.shape),
"pixel_values_shape": list(pixel_values.shape)
if pixel_values is not None
else None,
"pixel_values_shape": (
list(pixel_values.shape) if pixel_values is not None else None
),
},
)

Expand Down
6 changes: 6 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Gemma3ArchitectureAdapter,
Gemma3MultimodalArchitectureAdapter,
Gemma3nArchitectureAdapter,
Gemma4ArchitectureAdapter,
GPT2ArchitectureAdapter,
Gpt2LmHeadCustomArchitectureAdapter,
GPTBigCodeArchitectureAdapter,
Expand Down Expand Up @@ -86,6 +87,11 @@
"Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
"Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
"Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
"Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter,
# The unified (encoder-free) 12B variant's text decoder is a strict structural
# subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the
# same module paths. Requires transformers >= 5.10 to load.
"Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter,
"GraniteForCausalLM": GraniteArchitectureAdapter,
"GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from transformer_lens.model_bridge.generalized_components.block import (
BlockBridge,
DelegatedAttentionBlockBridge,
MLABlockBridge,
ParallelBlockBridge,
)
Expand Down Expand Up @@ -108,6 +109,7 @@
"AttentionBridge",
"AudioFeatureExtractorBridge",
"BlockBridge",
"DelegatedAttentionBlockBridge",
"MLABlockBridge",
"ParallelBlockBridge",
"BloomBlockBridge",
Expand Down
31 changes: 31 additions & 0 deletions transformer_lens/model_bridge/generalized_components/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,34 @@ def __init__(
if self.hook_aliases is BlockBridge.hook_aliases:
self.hook_aliases = dict(self.hook_aliases)
self.hook_aliases.pop("hook_resid_mid", None)


class DelegatedAttentionBlockBridge(BlockBridge):
    """Block bridge for layers whose attention is handed to HF as a whole.

    Some architectures have per-layer attention structure that is not uniform —
    e.g. Gemma 4, where KV-shared layers have no ``k_proj``/``v_proj`` at all and
    K==V layers have no ``v_proj``. With no single HookPoint standing for the
    "input that becomes Q/K/V", the block-level ``hook_q_input``,
    ``hook_k_input``, ``hook_v_input`` and ``hook_attn_in`` aliases are removed.
    Using a dedicated type makes the absence of those hooks visible to anyone
    reading an adapter that uses this class.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialise as a ``BlockBridge``, then drop the split-qkv fork aliases.

        All arguments are forwarded unchanged to ``BlockBridge.__init__``.
        """
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # Copy before mutating so the dict shared on BlockBridge is left untouched.
        if self.hook_aliases is BlockBridge.hook_aliases:
            self.hook_aliases = dict(BlockBridge.hook_aliases)
        fork_aliases = ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in")
        for fork_alias in fork_aliases:
            self.hook_aliases.pop(fork_alias, None)
4 changes: 4 additions & 0 deletions transformer_lens/model_bridge/sources/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def determine_architecture_from_hf_config(hf_config):
# gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration
# (vision/audio referenced but unbridged in the text-only adapter).
"gemma3n": "Gemma3nForConditionalGeneration",
# gemma4 is multimodal-only; all released checkpoints load as the full
# ForConditionalGeneration (vision is bridged by the adapter; audio is referenced but unbridged).
"gemma4": "Gemma4ForConditionalGeneration",
"gemma4_unified": "Gemma4UnifiedForConditionalGeneration",
"bert": "BertForMaskedLM",
"bloom": "BloomForCausalLM",
"codegen": "CodeGenForCausalLM",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from transformer_lens.model_bridge.supported_architectures.gemma3n import (
Gemma3nArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gemma4 import (
Gemma4ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gpt2 import (
GPT2ArchitectureAdapter,
)
Expand Down Expand Up @@ -189,6 +192,7 @@
"Gemma3ArchitectureAdapter",
"Gemma3nArchitectureAdapter",
"Gemma3MultimodalArchitectureAdapter",
"Gemma4ArchitectureAdapter",
"GraniteArchitectureAdapter",
"GraniteMoeArchitectureAdapter",
"GraniteMoeHybridArchitectureAdapter",
Expand Down
Loading
Loading