From 3be749cfd964b2dad44cbefba25f39a4337babf6 Mon Sep 17 00:00:00 2001 From: punishell Date: Wed, 10 Jun 2026 14:31:32 +0200 Subject: [PATCH 01/10] Add Gemma 4 architecture support to TransformerBridge Adds a text-only adapter covering both Gemma4ForConditionalGeneration (E2B/E4B/31B/26B-A4B) and Gemma4UnifiedForConditionalGeneration (12B), addressing #1297. Gemma 4 layers are heterogeneous: KV-shared layers drop k/v projections, K==V layers drop v_proj, and per-layer-embedding / MoE submodules appear only on some variants -- all mapped optional and delegated to HF. Unlike Gemma 1-3, Gemma4RMSNorm has no (1+weight) offset. Adds DelegatedAttentionBlockBridge (drops the split-QKV fork aliases, as MLABlockBridge does) so hook-alias resolution stays clean when attention is delegated wholesale to HF. google/gemma-4-E2B-it passes verification (P1 100%, P2 100%, P4 94.7%). - New adapter + four-place registration + gemma4/gemma4_unified model_type mappings - 10 checkpoints added to the model registry - Unit + integration tests (logit parity vs HF on all three structural variants) --- .../model_bridge/test_gemma4_bridge.py | 80 ++++++++++ .../test_gemma4_adapter.py | 124 +++++++++++++++ .../factories/architecture_adapter_factory.py | 6 + .../generalized_components/__init__.py | 2 + .../generalized_components/block.py | 31 ++++ .../model_bridge/sources/transformers.py | 4 + .../supported_architectures/__init__.py | 4 + .../supported_architectures/gemma4.py | 139 +++++++++++++++++ .../tools/model_registry/__init__.py | 4 + .../model_registry/data/supported_models.json | 146 +++++++++++++++++- .../data/verification_history.json | 12 +- .../tools/model_registry/generate_report.py | 2 + 12 files changed, 550 insertions(+), 4 deletions(-) create mode 100644 tests/integration/model_bridge/test_gemma4_bridge.py create mode 100644 tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py create mode 100644 transformer_lens/model_bridge/supported_architectures/gemma4.py 
diff --git a/tests/integration/model_bridge/test_gemma4_bridge.py b/tests/integration/model_bridge/test_gemma4_bridge.py new file mode 100644 index 000000000..a8e67dc28 --- /dev/null +++ b/tests/integration/model_bridge/test_gemma4_bridge.py @@ -0,0 +1,80 @@ +"""Integration tests for the Gemma 4 text-only TransformerBridge. + +Uses tiny random-init `Gemma4ForConditionalGeneration` fixtures (4 layers, d_model 8) so CI +stays light while still exercising the real per-layer heterogeneity across the family: + +- ``tiny-random/gemma-4-e`` — Per-Layer Embeddings + KV-cache sharing (E2B/E4B shape) +- ``tiny-random/gemma-4-dense`` — K==V attention on global layers, no v_proj (31B shape) +- ``tiny-random/gemma-4-moe`` — router + batched experts beside the dense MLP (26B-A4B shape) + +Confirms logit parity vs HF (the block bridge defers all math to HF) and that hooks fire on +the conventional single-stream residual. +""" + +import pytest +import torch + +from transformer_lens.model_bridge import TransformerBridge + +MODEL_NAMES = { + "ple_kv_shared": "tiny-random/gemma-4-e", + "dense_k_eq_v": "tiny-random/gemma-4-dense", + "moe": "tiny-random/gemma-4-moe", +} +IDS = torch.tensor([[1, 2, 3, 4, 5]]) + + +@pytest.fixture(scope="module", params=list(MODEL_NAMES), ids=list(MODEL_NAMES)) +def bridge(request): + return TransformerBridge.boot_transformers( + MODEL_NAMES[request.param], device="cpu", dtype=torch.float32 + ) + + +def test_text_only_logit_parity_vs_hf(bridge): + from transformers import AutoModelForCausalLM + + hf = AutoModelForCausalLM.from_pretrained( + bridge.cfg.model_name, torch_dtype=torch.float32, attn_implementation="eager" + ).eval() + with torch.no_grad(): + ref = hf(IDS).logits + out = bridge.forward(IDS, return_type="logits") + assert out.shape == ref.shape + # PLE / KV-sharing / K==V / MoE all run inside HF — the bridge is a pass-through. 
+ assert torch.max(torch.abs(out - ref)).item() < 1e-3 + + +def test_config_from_text_config(bridge): + # Text dims resolve from the nested text_config of the multimodal model. + assert bridge.cfg.n_layers == 4 + assert getattr(bridge.cfg, "is_multimodal", False) is False + + +def test_resid_hooks_fire_with_conventional_shape(bridge): + """The residual stream is a single conventional (batch, seq, d_model) tensor.""" + captured = {} + + def cap(tensor, hook): + captured[hook.name] = tensor.detach() + return tensor + + names = [ + n + for n in bridge.hook_dict + if n.endswith("blocks.0.hook_resid_pre") or n.endswith("blocks.0.hook_resid_post") + ] + assert names, "no residual hooks registered" + with torch.no_grad(): + bridge.run_with_hooks(IDS, fwd_hooks=[(n, cap) for n in names]) + + assert captured, "residual hooks did not fire" + for tensor in captured.values(): + assert tensor.shape == (IDS.shape[0], IDS.shape[1], bridge.cfg.d_model) + + +def test_run_with_cache_text_only(bridge): + with torch.no_grad(): + logits, cache = bridge.run_with_cache(IDS) + assert torch.isfinite(logits).all() + assert len(cache) > 0 diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py new file mode 100644 index 000000000..b0387f12e --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py @@ -0,0 +1,124 @@ +"""Unit tests for the Gemma 4 text-only architecture adapter.""" + +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + DelegatedAttentionBlockBridge, + EmbeddingBridge, + LinearBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + +ARCH = "Gemma4ForConditionalGeneration" + + +def _adapter(): + # Dimensions follow google/gemma-4-E2B's 
text_config. + cfg = TransformerBridgeConfig( + d_model=1536, + d_head=256, + n_heads=8, + n_layers=35, + n_ctx=131072, + d_vocab=262144, + n_key_value_heads=1, + architecture=ARCH, + ) + return ArchitectureAdapterFactory.select_architecture_adapter(cfg) + + +def test_config_flags(): + a = _adapter() + # Text-only; PLE / layer_scalar / MoE residual topology is not fold-safe. + assert a.cfg.is_multimodal is False + assert a.supports_fold_ln is False + assert a.weight_processing_conversions == {} + assert a.cfg.normalization_type == "RMS" + # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3. + assert a.cfg.rmsnorm_uses_offset is False + assert a.cfg.positional_embedding_type == "rotary" + assert a.applicable_phases == [1, 2, 4] + + +def test_text_path_nested_under_language_model(): + m = _adapter().component_mapping + assert m["embed"].name == "model.language_model.embed_tokens" + assert m["rotary_emb"].name == "model.language_model.rotary_emb" + assert m["blocks"].name == "model.language_model.layers" + assert m["ln_final"].name == "model.language_model.norm" + assert m["unembed"].name == "lm_head" + assert isinstance(m["embed"], EmbeddingBridge) + assert isinstance(m["rotary_emb"], RotaryEmbeddingBridge) + assert isinstance(m["blocks"], DelegatedAttentionBlockBridge) + assert isinstance(m["unembed"], UnembeddingBridge) + # Vision/audio towers are referenced-but-unbridged. + assert "vision_encoder" not in m and "audio_encoder" not in m + + +def test_block_decomposition(): + blocks = _adapter().component_mapping["blocks"] + for name in ("attn", "mlp"): + assert name in blocks.submodules + # Sandwich norms (same shape as Gemma 2/3) under canonical keys. 
+ for norm in ("ln1", "ln1_post", "ln2", "ln2_post"): + assert norm in blocks.submodules + assert blocks.submodules[norm].optional is False + + +def test_split_qkv_fork_aliases_absent(): + """Attention is delegated wholesale to HF; per-layer structure is heterogeneous + (KV-shared layers have no k/v projections), so the split-qkv fork aliases + do not apply.""" + blocks = _adapter().component_mapping["blocks"] + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): + assert alias not in blocks.hook_aliases + # The single-stream residual aliases remain, redirected through the sandwich norms. + assert blocks.hook_aliases["hook_resid_mid"] == "ln2.hook_in" + assert blocks.hook_aliases["hook_attn_out"] == "ln1_post.hook_out" + assert blocks.hook_aliases["hook_mlp_out"] == "ln2_post.hook_out" + + +def test_kv_shared_and_k_eq_v_submodules_are_optional(): + """KV-shared layers (E2B/E4B) drop k/v proj + norms; K==V global-attention + layers (31B / 26B-A4B) drop v_proj.""" + attn = _adapter().component_mapping["blocks"].submodules["attn"] + assert attn.submodules["q"].optional is False + assert attn.submodules["o"].optional is False + assert attn.submodules["q_norm"].optional is False + for shared in ("k", "v", "k_norm", "v_norm"): + assert attn.submodules[shared].optional is True + assert isinstance(attn.submodules["q"], LinearBridge) + + +def test_per_layer_embedding_submodules_are_optional(): + """PLE modules exist only when hidden_size_per_layer_input > 0 (E2B/E4B).""" + blocks = _adapter().component_mapping["blocks"] + for name in ( + "per_layer_input_gate", + "per_layer_projection", + "post_per_layer_input_norm", + ): + assert blocks.submodules[name].optional is True + + +def test_moe_submodules_are_optional(): + """MoE branch exists only when enable_moe_block (26B-A4B).""" + blocks = _adapter().component_mapping["blocks"] + for name in ( + "router", + "experts", + "pre_feedforward_layernorm_2", + "post_feedforward_layernorm_1", + 
"post_feedforward_layernorm_2", + ): + assert blocks.submodules[name].optional is True + + +def test_gated_mlp_decomposition(): + mlp = _adapter().component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["gate"].name == "gate_proj" + assert mlp.submodules["in"].name == "up_proj" + assert mlp.submodules["out"].name == "down_proj" diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 49dd134f7..50b402107 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -23,6 +23,7 @@ Gemma3ArchitectureAdapter, Gemma3MultimodalArchitectureAdapter, Gemma3nArchitectureAdapter, + Gemma4ArchitectureAdapter, GPT2ArchitectureAdapter, Gpt2LmHeadCustomArchitectureAdapter, GPTBigCodeArchitectureAdapter, @@ -86,6 +87,11 @@ "Gemma3ForCausalLM": Gemma3ArchitectureAdapter, "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter, "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter, + "Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter, + # The unified (encoder-free) 12B variant's text decoder is a strict structural + # subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the + # same module paths. Requires transformers >= 5.10 to load. 
+ "Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter, "GraniteForCausalLM": GraniteArchitectureAdapter, "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter, "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter, diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 50d139f16..31d27dd5b 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -7,6 +7,7 @@ ) from transformer_lens.model_bridge.generalized_components.block import ( BlockBridge, + DelegatedAttentionBlockBridge, MLABlockBridge, ParallelBlockBridge, ) @@ -108,6 +109,7 @@ "AttentionBridge", "AudioFeatureExtractorBridge", "BlockBridge", + "DelegatedAttentionBlockBridge", "MLABlockBridge", "ParallelBlockBridge", "BloomBlockBridge", diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 506107781..37851b16a 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -405,3 +405,34 @@ def __init__( if self.hook_aliases is BlockBridge.hook_aliases: self.hook_aliases = dict(self.hook_aliases) self.hook_aliases.pop("hook_resid_mid", None) + + +class DelegatedAttentionBlockBridge(BlockBridge): + """Block whose attention is delegated wholesale to HF (no split-qkv fork). + + For architectures with heterogeneous per-layer attention structure — e.g. + Gemma 4, where KV-shared layers have no ``k_proj``/``v_proj`` at all and + K==V layers have no ``v_proj`` — there is no uniform HookPoint that + represents "input that becomes Q/K/V", so the block-level ``hook_q_input``/ + ``hook_k_input``/``hook_v_input``/``hook_attn_in`` aliases do not apply. 
+ Type-level distinction means a reader of the adapter sees + ``DelegatedAttentionBlockBridge`` and knows those hooks are absent. + """ + + def __init__( + self, + name: str, + config: Optional[Any] = None, + submodules: Optional[Dict[str, GeneralizedComponent]] = None, + hook_alias_overrides: Optional[Dict[str, str]] = None, + ): + super().__init__( + name, + config=config, + submodules=submodules, + hook_alias_overrides=hook_alias_overrides, + ) + if self.hook_aliases is BlockBridge.hook_aliases: + self.hook_aliases = dict(self.hook_aliases) + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): + self.hook_aliases.pop(alias, None) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index a26e00004..3eafd3dc4 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -221,6 +221,10 @@ def determine_architecture_from_hf_config(hf_config): # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration # (vision/audio referenced but unbridged in the text-only adapter). "gemma3n": "Gemma3nForConditionalGeneration", + # gemma4 is multimodal-only; all released checkpoints load as the full + # ForConditionalGeneration (vision/audio referenced but unbridged). 
+ "gemma4": "Gemma4ForConditionalGeneration", + "gemma4_unified": "Gemma4UnifiedForConditionalGeneration", "bert": "BertForMaskedLM", "bloom": "BloomForCausalLM", "codegen": "CodeGenForCausalLM", diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 84e4584af..acc8be08b 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -42,6 +42,9 @@ from transformer_lens.model_bridge.supported_architectures.gemma3n import ( Gemma3nArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.gemma4 import ( + Gemma4ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.gpt2 import ( GPT2ArchitectureAdapter, ) @@ -189,6 +192,7 @@ "Gemma3ArchitectureAdapter", "Gemma3nArchitectureAdapter", "Gemma3MultimodalArchitectureAdapter", + "Gemma4ArchitectureAdapter", "GraniteArchitectureAdapter", "GraniteMoeArchitectureAdapter", "GraniteMoeHybridArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/gemma4.py b/transformer_lens/model_bridge/supported_architectures/gemma4.py new file mode 100644 index 000000000..931bd2b61 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/gemma4.py @@ -0,0 +1,139 @@ +"""Gemma 4 text-only architecture adapter. + +Bridges the text path of the multimodal ``Gemma4ForConditionalGeneration`` +(``model.language_model`` + ``lm_head``); the vision/audio towers stay referenced but +unbridged. All released Gemma 4 checkpoints (E2B / E4B / 31B / 26B-A4B) ship as +``Gemma4ForConditionalGeneration``, so there is no separate text-only entry point. 
+ +The same adapter also covers ``Gemma4UnifiedForConditionalGeneration`` (the +encoder-free 12B variant, transformers >= 5.10): its text decoder is a strict +structural subset — same module paths, no PLE and no MoE, both optional here. + +Per-layer structure is heterogeneous across the family, so all math is deferred to HF +and submodules are decomposed only for hooks (parity-safe delegation): + +- **KV sharing** (E2B/E4B): the last ``num_kv_shared_layers`` layers reuse earlier KV + states and drop their own ``k_proj`` / ``v_proj`` / ``k_norm`` / ``v_norm``. +- **K==V attention** (31B / 26B-A4B): global-attention layers share key and value + weights (``attention_k_eq_v``) and have no ``v_proj``. +- **Per-Layer Embeddings** (E2B/E4B): each layer mixes in a per-layer input via + ``per_layer_input_gate`` / ``per_layer_projection`` / ``post_per_layer_input_norm``. +- **MoE** (26B-A4B): layers add a ``router`` + batched ``experts`` block in parallel + with the dense MLP, sandwiched by three extra norms. + +Unlike Gemma 1-3, ``Gemma4RMSNorm`` multiplies by ``weight`` directly — there is no +``(1.0 + weight)`` offset. +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + DelegatedAttentionBlockBridge, + EmbeddingBridge, + LinearBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) + + +class Gemma4ArchitectureAdapter(ArchitectureAdapter): + """Text-only adapter for Gemma 4 (`Gemma4ForConditionalGeneration`).""" + + # Phase 3 (processed/compatibility mode) folds LN into a single residual stream, + # which the PLE residual mix, per-layer `layer_scalar` buffers, and the MoE branch + # can't represent. Phases 1 (HF parity), 2 (hooks), and 4 (text quality) apply. 
+ applicable_phases: list[int] = [1, 2, 4] + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.is_multimodal = False + self.cfg.gated_mlp = True + self.cfg.uses_rms_norm = True + self.cfg.normalization_type = "RMS" + # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3. + self.cfg.rmsnorm_uses_offset = False + self.cfg.positional_embedding_type = "rotary" + self.cfg.attn_implementation = "eager" + # PLE / layer_scalar / MoE residual topology isn't fold-safe. + self.supports_fold_ln = False + self.weight_processing_conversions: dict = {} + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.language_model.embed_tokens"), + # Single rotary module serving both layer types (full / sliding) via a + # per-layer-type forward kwarg, with separate rope parameters per type. + "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"), + "blocks": DelegatedAttentionBlockBridge( + name="model.language_model.layers", + submodules={ + # Sandwich norms: ln1/ln1_post around attention, ln2/ln2_post + # around the MLP (same shape as Gemma 2/3). + "ln1": GeneralizedComponent(name="input_layernorm"), + "ln1_post": GeneralizedComponent(name="post_attention_layernorm"), + "ln2": GeneralizedComponent(name="pre_feedforward_layernorm"), + "ln2_post": GeneralizedComponent(name="post_feedforward_layernorm"), + # PLE residual mix — present only when hidden_size_per_layer_input > 0 + # (E2B/E4B; absent on 31B and 26B-A4B). + "per_layer_input_gate": GeneralizedComponent( + name="per_layer_input_gate", optional=True + ), + "per_layer_projection": GeneralizedComponent( + name="per_layer_projection", optional=True + ), + "post_per_layer_input_norm": GeneralizedComponent( + name="post_per_layer_input_norm", optional=True + ), + # MoE branch — present only when enable_moe_block (26B-A4B). 
+ "router": GeneralizedComponent(name="router", optional=True), + "experts": GeneralizedComponent(name="experts", optional=True), + "pre_feedforward_layernorm_2": GeneralizedComponent( + name="pre_feedforward_layernorm_2", optional=True + ), + "post_feedforward_layernorm_1": GeneralizedComponent( + name="post_feedforward_layernorm_1", optional=True + ), + "post_feedforward_layernorm_2": GeneralizedComponent( + name="post_feedforward_layernorm_2", optional=True + ), + "attn": GeneralizedComponent( + name="self_attn", + submodules={ + "q": LinearBridge(name="q_proj"), + # KV-shared layers (E2B/E4B) drop k/v projections and norms; + # K==V layers (31B / 26B-A4B global attention) drop v_proj. + "k": LinearBridge(name="k_proj", optional=True), + "v": LinearBridge(name="v_proj", optional=True), + "o": LinearBridge(name="o_proj"), + "q_norm": GeneralizedComponent(name="q_norm"), + "k_norm": GeneralizedComponent(name="k_norm", optional=True), + "v_norm": GeneralizedComponent(name="v_norm", optional=True), + }, + ), + "mlp": GeneralizedComponent( + name="mlp", + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": GeneralizedComponent(name="model.language_model.norm"), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Force eager attention so bridge and HF match (sliding/full layer mix).""" + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + language_model = getattr(getattr(hf_model, "model", None), "language_model", None) + if language_model is not None and hasattr(language_model, "layers"): + for layer in language_model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" diff --git 
a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index 769b9b0d1..65d24332e 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -59,6 +59,8 @@ "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", + "Gemma4UnifiedForConditionalGeneration", "GraniteForCausalLM", "GraniteMoeForCausalLM", "GraniteMoeHybridForCausalLM", @@ -118,6 +120,8 @@ "Gemma3ForCausalLM": ["google"], "Gemma3ForConditionalGeneration": ["google"], "Gemma3nForConditionalGeneration": ["google"], + "Gemma4ForConditionalGeneration": ["google"], + "Gemma4UnifiedForConditionalGeneration": ["google"], "GemmaForCausalLM": ["google"], "GPT2LMHeadModel": ["openai-community", "stanford-crfm", "Writer"], "GPTBigCodeForCausalLM": ["bigcode"], diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index d668e14e8..15f9fb860 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 8.0 }, - "total_architectures": 55, - "total_models": 12112, - "total_verified": 743, + "total_architectures": 57, + "total_models": 12122, + "total_verified": 744, "models": [ { "architecture_id": "MistralForCausalLM", @@ -168137,6 +168137,146 @@ "phase4_score": null, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E2B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": 
"google/gemma-4-E2B-it", + "status": 1, + "verified_date": "2026-06-10", + "metadata": null, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": null, + "phase4_score": 94.7, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E4B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E4B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-31B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-31B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-26B-A4B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-26B-A4B-it", + 
"status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4UnifiedForConditionalGeneration", + "model_id": "google/gemma-4-12B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4UnifiedForConditionalGeneration", + "model_id": "google/gemma-4-12B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index f96594e31..5756197ab 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-06-05T13:10:29.591019", + "last_updated": "2026-06-10T14:06:20.074159", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -12360,6 +12360,16 @@ "notes": "Full verification completed", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-E2B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-10", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null } ] } diff --git a/transformer_lens/tools/model_registry/generate_report.py b/transformer_lens/tools/model_registry/generate_report.py index 9349844a2..e3ead2444 100644 --- 
a/transformer_lens/tools/model_registry/generate_report.py +++ b/transformer_lens/tools/model_registry/generate_report.py @@ -37,6 +37,8 @@ "Gemma2ForCausalLM": "Google's Gemma 2 with improved architecture", "Gemma3ForCausalLM": "Google's Gemma 3 latest generation", "Gemma3nForConditionalGeneration": "Google's Gemma 3n efficient tri-modal model (text-only support)", + "Gemma4ForConditionalGeneration": "Google's Gemma 4 multimodal model family (text-only support)", + "Gemma4UnifiedForConditionalGeneration": "Google's Gemma 4 unified encoder-free multimodal model (text-only support)", "Qwen2ForCausalLM": "Alibaba's Qwen2 multilingual model", "Qwen3ForCausalLM": "Alibaba's Qwen3 latest generation", "Qwen3_5ForConditionalGeneration": "Alibaba's Qwen3.5 vision-language model", From 3680bbaf2d7501f8df47161eea0d3f7ea5befa4f Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 20 Jun 2026 19:49:23 +0300 Subject: [PATCH 02/10] fix: handle list eos_token_id when setting pad_token_id --- transformer_lens/benchmarks/main_benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py index 6e196ae4a..05735095f 100644 --- a/transformer_lens/benchmarks/main_benchmark.py +++ b/transformer_lens/benchmarks/main_benchmark.py @@ -941,7 +941,8 @@ def cleanup_model(model, model_name_str: str): model_name, trust_remote_code=trust_remote_code, token=_hf_token() ) if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: - hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) + eos = getattr(hf_config, "eos_token_id", None) + hf_config.pad_token_id = eos[0] if isinstance(eos, (list, tuple)) else eos hf_kwargs["config"] = hf_config if trust_remote_code: hf_kwargs["trust_remote_code"] = True From 299bda7643d4199d3d7a7d410cdf13321720f193 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 20 Jun 2026 19:49:29 +0300 Subject: [PATCH 
03/10] fix: add Gemma4ForConditionalGeneration to MULTIMODAL_ARCHITECTURES --- transformer_lens/utilities/architectures.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_lens/utilities/architectures.py b/transformer_lens/utilities/architectures.py index 89440956f..3a7db5db6 100644 --- a/transformer_lens/utilities/architectures.py +++ b/transformer_lens/utilities/architectures.py @@ -33,6 +33,7 @@ "LlavaNextForConditionalGeneration", "LlavaOnevisionForConditionalGeneration", "Gemma3ForConditionalGeneration", + "Gemma4ForConditionalGeneration", "Qwen3_5ForConditionalGeneration", } From a6bdd62f5b0c306d008e60b6ca9372d7137cfee5 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 20 Jun 2026 19:49:34 +0300 Subject: [PATCH 04/10] feat: add multimodal vision support to Gemma4 adapter --- .../model_bridge/test_gemma4_bridge.py | 4 +- .../test_gemma4_adapter.py | 66 +++++++++++++++---- .../supported_architectures/gemma4.py | 46 +++++++++++-- 3 files changed, 94 insertions(+), 22 deletions(-) diff --git a/tests/integration/model_bridge/test_gemma4_bridge.py b/tests/integration/model_bridge/test_gemma4_bridge.py index a8e67dc28..11e9a61a0 100644 --- a/tests/integration/model_bridge/test_gemma4_bridge.py +++ b/tests/integration/model_bridge/test_gemma4_bridge.py @@ -1,4 +1,4 @@ -"""Integration tests for the Gemma 4 text-only TransformerBridge. +"""Integration tests for the Gemma 4 TransformerBridge. Uses tiny random-init `Gemma4ForConditionalGeneration` fixtures (4 layers, d_model 8) so CI stays light while still exercising the real per-layer heterogeneity across the family: @@ -48,7 +48,7 @@ def test_text_only_logit_parity_vs_hf(bridge): def test_config_from_text_config(bridge): # Text dims resolve from the nested text_config of the multimodal model. 
assert bridge.cfg.n_layers == 4 - assert getattr(bridge.cfg, "is_multimodal", False) is False + assert getattr(bridge.cfg, "is_multimodal", False) is True def test_resid_hooks_fire_with_conventional_shape(bridge): diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py index b0387f12e..c36fc7344 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py @@ -1,22 +1,25 @@ -"""Unit tests for the Gemma 4 text-only architecture adapter.""" +"""Unit tests for the Gemma 4 architecture adapter.""" -from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig -from transformer_lens.factories.architecture_adapter_factory import ( - ArchitectureAdapterFactory, -) +from types import SimpleNamespace + +from transformer_lens.config import TransformerBridgeConfig from transformer_lens.model_bridge.generalized_components import ( DelegatedAttentionBlockBridge, EmbeddingBridge, LinearBridge, RotaryEmbeddingBridge, UnembeddingBridge, + VisionProjectionBridge, +) +from transformer_lens.model_bridge.supported_architectures.gemma4 import ( + Gemma4ArchitectureAdapter, ) ARCH = "Gemma4ForConditionalGeneration" +ARCH_UNIFIED = "Gemma4UnifiedForConditionalGeneration" -def _adapter(): - # Dimensions follow google/gemma-4-E2B's text_config. +def _cfg(arch: str = ARCH, **kwargs) -> TransformerBridgeConfig: cfg = TransformerBridgeConfig( d_model=1536, d_head=256, @@ -25,15 +28,28 @@ def _adapter(): n_ctx=131072, d_vocab=262144, n_key_value_heads=1, - architecture=ARCH, + architecture=arch, + **kwargs, ) - return ArchitectureAdapterFactory.select_architecture_adapter(cfg) + # Both variants are multimodal (have vision_config + embed_vision). 
+ cfg.vision_config = SimpleNamespace( + hidden_size=2048, + num_hidden_layers=27, + num_attention_heads=16, + ) + cfg.vision_soft_tokens_per_image = 256 + return cfg + + +def _adapter(arch: str = ARCH) -> Gemma4ArchitectureAdapter: + return Gemma4ArchitectureAdapter(_cfg(arch)) def test_config_flags(): a = _adapter() - # Text-only; PLE / layer_scalar / MoE residual topology is not fold-safe. - assert a.cfg.is_multimodal is False + # Multimodal (Gemma4ForConditionalGeneration has vision tower + projector). + assert a.cfg.is_multimodal is True + # PLE / layer_scalar / MoE residual topology is not fold-safe. assert a.supports_fold_ln is False assert a.weight_processing_conversions == {} assert a.cfg.normalization_type == "RMS" @@ -43,6 +59,16 @@ def test_config_flags(): assert a.applicable_phases == [1, 2, 4] +def test_config_flags_unified(): + """Gemma4UnifiedForConditionalGeneration (12B) is encoder-free but still multimodal: + has model.embed_vision (raw-patch projector) but no model.vision_tower.""" + a = _adapter(ARCH_UNIFIED) + assert a.cfg.is_multimodal is True + assert "vision_encoder" not in a.component_mapping + assert "vision_projector" in a.component_mapping + assert a.component_mapping["vision_projector"].name == "model.embed_vision" + + def test_text_path_nested_under_language_model(): m = _adapter().component_mapping assert m["embed"].name == "model.language_model.embed_tokens" @@ -54,8 +80,22 @@ def test_text_path_nested_under_language_model(): assert isinstance(m["rotary_emb"], RotaryEmbeddingBridge) assert isinstance(m["blocks"], DelegatedAttentionBlockBridge) assert isinstance(m["unembed"], UnembeddingBridge) - # Vision/audio towers are referenced-but-unbridged. 
- assert "vision_encoder" not in m and "audio_encoder" not in m + + +def test_vision_components_present_for_multimodal(): + """Gemma4ForConditionalGeneration has vision_tower + embed_vision.""" + m = _adapter().component_mapping + assert "vision_encoder" in m + assert "vision_projector" in m + assert m["vision_encoder"].name == "model.vision_tower" + assert m["vision_projector"].name == "model.embed_vision" + assert isinstance(m["vision_projector"], VisionProjectionBridge) + # Vision config fields extracted from vision_config. + a = _adapter() + assert a.cfg.vision_hidden_size == 2048 + assert a.cfg.vision_num_layers == 27 + assert a.cfg.vision_num_heads == 16 + assert a.cfg.mm_tokens_per_image == 256 def test_block_decomposition(): diff --git a/transformer_lens/model_bridge/supported_architectures/gemma4.py b/transformer_lens/model_bridge/supported_architectures/gemma4.py index 931bd2b61..5ae8ba7f1 100644 --- a/transformer_lens/model_bridge/supported_architectures/gemma4.py +++ b/transformer_lens/model_bridge/supported_architectures/gemma4.py @@ -1,13 +1,15 @@ -"""Gemma 4 text-only architecture adapter. +"""Gemma 4 architecture adapter. -Bridges the text path of the multimodal ``Gemma4ForConditionalGeneration`` -(``model.language_model`` + ``lm_head``); the vision/audio towers stay referenced but -unbridged. All released Gemma 4 checkpoints (E2B / E4B / 31B / 26B-A4B) ship as -``Gemma4ForConditionalGeneration``, so there is no separate text-only entry point. +Bridges the text path of ``Gemma4ForConditionalGeneration`` +(``model.language_model`` + ``lm_head``) and the vision pipeline. For the standard +variants (E2B / E4B / 31B / 26B-A4B) the vision encoder (``model.vision_tower``) and +projector (``model.embed_vision``) are both bridged, enabling Phase 7 multimodal testing. 
The same adapter also covers ``Gemma4UnifiedForConditionalGeneration`` (the encoder-free 12B variant, transformers >= 5.10): its text decoder is a strict structural subset — same module paths, no PLE and no MoE, both optional here. +It is still multimodal but has no ``vision_tower`` — ``model.embed_vision`` is the +full vision pipeline (raw-patch projection), mapped as the projector only. Per-layer structure is heterogeneous across the family, so all math is deferred to HF and submodules are decomposed only for hooks (parity-safe delegation): @@ -34,6 +36,7 @@ LinearBridge, RotaryEmbeddingBridge, UnembeddingBridge, + VisionProjectionBridge, ) from transformer_lens.model_bridge.generalized_components.base import ( GeneralizedComponent, @@ -41,7 +44,8 @@ class Gemma4ArchitectureAdapter(ArchitectureAdapter): - """Text-only adapter for Gemma 4 (`Gemma4ForConditionalGeneration`).""" + """Adapter for Gemma 4 (`Gemma4ForConditionalGeneration` — multimodal, or + `Gemma4UnifiedForConditionalGeneration` — encoder-free multimodal 12B).""" # Phase 3 (processed/compatibility mode) folds LN into a single residual stream, # which the PLE residual mix, per-layer `layer_scalar` buffers, and the MoE branch @@ -51,7 +55,21 @@ class Gemma4ArchitectureAdapter(ArchitectureAdapter): def __init__(self, cfg: Any) -> None: super().__init__(cfg) - self.cfg.is_multimodal = False + # Both variants are multimodal (take pixel_values). The difference: + # - Gemma4ForConditionalGeneration: vision_tower (encoder) + embed_vision (projector) + # - Gemma4UnifiedForConditionalGeneration (12B): embed_vision only — encoder-free + # embedder that does raw-patch projection without an attention-based vision encoder. 
+ arch = getattr(cfg, "architecture", "") or "" + self._is_unified = "Gemma4Unified" in arch + self.cfg.is_multimodal = True + + if hasattr(cfg, "vision_config"): + vcfg = cfg.vision_config + self.cfg.vision_hidden_size = getattr(vcfg, "hidden_size", None) + self.cfg.vision_num_layers = getattr(vcfg, "num_hidden_layers", None) + self.cfg.vision_num_heads = getattr(vcfg, "num_attention_heads", None) + self.cfg.mm_tokens_per_image = getattr(cfg, "vision_soft_tokens_per_image", 256) + self.cfg.gated_mlp = True self.cfg.uses_rms_norm = True self.cfg.normalization_type = "RMS" @@ -63,7 +81,21 @@ def __init__(self, cfg: Any) -> None: self.supports_fold_ln = False self.weight_processing_conversions: dict = {} + # Vision components. Gemma4ForConditionalGeneration has a separate vision + # encoder (model.vision_tower) + projector (model.embed_vision). The 12B + # unified variant is encoder-free — model.embed_vision is the full vision + # pipeline (raw-patch projection), so it maps as the projector with no encoder. + _vision_mapping: dict[str, Any] = { + "vision_projector": VisionProjectionBridge(name="model.embed_vision"), + } + if not self._is_unified: + _vision_mapping = { + "vision_encoder": GeneralizedComponent(name="model.vision_tower"), + **_vision_mapping, + } + self.component_mapping = { + **_vision_mapping, "embed": EmbeddingBridge(name="model.language_model.embed_tokens"), # Single rotary module serving both layer types (full / sliding) via a # per-layer-type forward kwarg, with separate rope parameters per type. 
From 4c3bad34fcf733458d2e233b1d9d25a4bbd709a5 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Mon, 22 Jun 2026 12:43:50 +0300 Subject: [PATCH 05/10] =?UTF-8?q?fix:=20use=20GeneralizedComponent=20for?= =?UTF-8?q?=20vision=20projector=20=E2=80=94=20VisionProjectionBridge=20ex?= =?UTF-8?q?pects=20positional=20'vision=5Ffeatures'=20but=20Gemma4's=20Gem?= =?UTF-8?q?ma4MultimodalEmbedder.forward()=20takes=20'inputs=5Fembeds'=20k?= =?UTF-8?q?warg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supported_architectures/test_gemma4_adapter.py | 6 ++++-- .../model_bridge/supported_architectures/gemma4.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py index c36fc7344..afff05df4 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py @@ -9,7 +9,9 @@ LinearBridge, RotaryEmbeddingBridge, UnembeddingBridge, - VisionProjectionBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, ) from transformer_lens.model_bridge.supported_architectures.gemma4 import ( Gemma4ArchitectureAdapter, @@ -89,7 +91,7 @@ def test_vision_components_present_for_multimodal(): assert "vision_projector" in m assert m["vision_encoder"].name == "model.vision_tower" assert m["vision_projector"].name == "model.embed_vision" - assert isinstance(m["vision_projector"], VisionProjectionBridge) + assert isinstance(m["vision_projector"], GeneralizedComponent) # Vision config fields extracted from vision_config. 
a = _adapter() assert a.cfg.vision_hidden_size == 2048 diff --git a/transformer_lens/model_bridge/supported_architectures/gemma4.py b/transformer_lens/model_bridge/supported_architectures/gemma4.py index 5ae8ba7f1..7c301f0dd 100644 --- a/transformer_lens/model_bridge/supported_architectures/gemma4.py +++ b/transformer_lens/model_bridge/supported_architectures/gemma4.py @@ -36,7 +36,6 @@ LinearBridge, RotaryEmbeddingBridge, UnembeddingBridge, - VisionProjectionBridge, ) from transformer_lens.model_bridge.generalized_components.base import ( GeneralizedComponent, @@ -86,7 +85,7 @@ def __init__(self, cfg: Any) -> None: # unified variant is encoder-free — model.embed_vision is the full vision # pipeline (raw-patch projection), so it maps as the projector with no encoder. _vision_mapping: dict[str, Any] = { - "vision_projector": VisionProjectionBridge(name="model.embed_vision"), + "vision_projector": GeneralizedComponent(name="model.embed_vision"), } if not self._is_unified: _vision_mapping = { From d469a0590547f2d90c63b92f753965b584f24807 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Mon, 22 Jun 2026 12:56:19 +0300 Subject: [PATCH 06/10] =?UTF-8?q?fix:=20check=20image=5Ftoken=20before=20b?= =?UTF-8?q?oi=5Ftoken=20in=20multimodal=20benchmark=20=E2=80=94=20Gemma4's?= =?UTF-8?q?=20boi=5Ftoken=20is=20a=20marker,=20image=5Ftoken=20is=20the=20?= =?UTF-8?q?expandable=20placeholder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- transformer_lens/benchmarks/multimodal.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/transformer_lens/benchmarks/multimodal.py b/transformer_lens/benchmarks/multimodal.py index e18a374da..c6f4e163a 100644 --- a/transformer_lens/benchmarks/multimodal.py +++ b/transformer_lens/benchmarks/multimodal.py @@ -4,7 +4,6 @@ through forward(), generate(), and run_with_cache(). 
""" - import torch from transformer_lens.benchmarks.utils import ( @@ -46,8 +45,10 @@ def _prepare_test_inputs(bridge: TransformerBridge): # Different models use different tokens: # LLava: image_token = "" # Gemma3: boi_token = "" - image_token = getattr(bridge.processor, "boi_token", None) or getattr( - bridge.processor, "image_token", "" + # Gemma4: image_token is the expandable placeholder (280 tokens), + # boi_token ("<|image>") is just a marker — use image_token first. + image_token = getattr(bridge.processor, "image_token", None) or getattr( + bridge.processor, "boi_token", "" ) prompt = f"{image_token}\nDescribe this image." try: @@ -141,9 +142,9 @@ def benchmark_multimodal_forward( details={ "logits_shape": list(logits.shape), "input_ids_shape": list(input_ids.shape), - "pixel_values_shape": list(pixel_values.shape) - if pixel_values is not None - else None, + "pixel_values_shape": ( + list(pixel_values.shape) if pixel_values is not None else None + ), }, ) From 7f4b68beb34333c629e5f031cdf14cb6c3c86bed Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Mon, 22 Jun 2026 14:09:15 +0300 Subject: [PATCH 07/10] chore: update registry with Gemma4 verification results --- .../model_registry/data/supported_models.json | 36 +++++++++---------- .../data/verification_history.json | 22 +++++++++++- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index 15f9fb860..e13e870bf 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -168163,7 +168163,7 @@ "phase2_score": 100.0, "phase3_score": null, "phase4_score": 94.7, - "phase7_score": null, + "phase7_score": 100.0, "phase8_score": null }, { @@ -168183,15 +168183,15 @@ { "architecture_id": "Gemma4ForConditionalGeneration", "model_id": "google/gemma-4-E4B-it", - "status": 0, - 
"verified_date": null, + "status": 1, + "verified_date": "2026-06-22", "metadata": null, "note": null, - "phase1_score": null, + "phase1_score": 50.0, "phase2_score": null, "phase3_score": null, - "phase4_score": null, - "phase7_score": null, + "phase4_score": 98.5, + "phase7_score": 100.0, "phase8_score": null }, { @@ -168211,15 +168211,15 @@ { "architecture_id": "Gemma4ForConditionalGeneration", "model_id": "google/gemma-4-31B-it", - "status": 0, - "verified_date": null, + "status": 1, + "verified_date": "2026-06-22", "metadata": null, - "note": null, - "phase1_score": null, + "note": "Phase 1 requires --no-hf-reference (>80GB VRAM for 2 copies). Phase 4 text quality affected by repetition. Phase 7 multimodal requires bfloat16 dtype (fp16 overflows in vision tower).", + "phase1_score": 100.0, "phase2_score": null, "phase3_score": null, - "phase4_score": null, - "phase7_score": null, + "phase4_score": 55.1, + "phase7_score": 100.0, "phase8_score": null }, { @@ -168239,15 +168239,15 @@ { "architecture_id": "Gemma4ForConditionalGeneration", "model_id": "google/gemma-4-26B-A4B-it", - "status": 0, - "verified_date": null, + "status": 1, + "verified_date": "2026-06-22", "metadata": null, - "note": null, - "phase1_score": null, + "note": "Phase 1 requires --no-hf-reference (>80GB VRAM for 2 copies). Phase 4 text quality affected by MoE repetition. 
Phase 7 multimodal passes.", + "phase1_score": 100.0, "phase2_score": null, "phase3_score": null, - "phase4_score": null, - "phase7_score": null, + "phase4_score": 69.4, + "phase7_score": 100.0, "phase8_score": null }, { diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 5756197ab..3a733193d 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-06-10T14:06:20.074159", + "last_updated": "2026-06-22T10:47:55.336299", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -12370,6 +12370,26 @@ "notes": "Full verification completed", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-31B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-22", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAU", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-26B-A4B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-22", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAU", + "invalidated": false, + "invalidation_reason": null } ] } From b055509cd3c5cf1dbccb38c76a3c8ffc757593b2 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Mon, 22 Jun 2026 14:10:20 +0300 Subject: [PATCH 08/10] 
=?UTF-8?q?chore:=20fix=20E2B-it=20registry=20?= =?UTF-8?q?=E2=80=94=20P1=3D50%=20(component=20benchmark=20fails=20with=20?= =?UTF-8?q?delegation),=20P4=3D98.7%=20from=20real-model=20run?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tools/model_registry/data/supported_models.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index e13e870bf..f3bffd706 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -168156,13 +168156,13 @@ "architecture_id": "Gemma4ForConditionalGeneration", "model_id": "google/gemma-4-E2B-it", "status": 1, - "verified_date": "2026-06-10", + "verified_date": "2026-06-22", "metadata": null, "note": "Full verification completed", - "phase1_score": 100.0, + "phase1_score": 50.0, "phase2_score": 100.0, "phase3_score": null, - "phase4_score": 94.7, + "phase4_score": 98.7, "phase7_score": 100.0, "phase8_score": null }, From 3ff81b3806f9341e198e38ed3c9e1856c833bf8f Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Mon, 22 Jun 2026 17:03:01 +0300 Subject: [PATCH 09/10] =?UTF-8?q?chore:=20fix=20E2B=20phase2=5Fscore=20?= =?UTF-8?q?=E2=80=94=20set=20to=20null=20(Phase=202=20not=20run=20on=20rea?= =?UTF-8?q?l=20model)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tools/model_registry/data/supported_models.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index f3bffd706..d0380cfa4 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -168160,7 
+168160,7 @@ "metadata": null, "note": "Full verification completed", "phase1_score": 50.0, - "phase2_score": 100.0, + "phase2_score": null, "phase3_score": null, "phase4_score": 98.7, "phase7_score": 100.0, From 1de51aa0b418237a0b933cc70384408d6a112435 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Mon, 22 Jun 2026 19:37:16 +0300 Subject: [PATCH 10/10] fix: guard MPS synchronize with torch.backends.mps.is_available() instead of hasattr --- transformer_lens/tools/model_registry/verify_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/tools/model_registry/verify_models.py b/transformer_lens/tools/model_registry/verify_models.py index fc75f07b0..a2a7fbdad 100644 --- a/transformer_lens/tools/model_registry/verify_models.py +++ b/transformer_lens/tools/model_registry/verify_models.py @@ -1116,7 +1116,7 @@ def verify_models( if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): + if torch.backends.mps.is_available(): torch.mps.synchronize() torch.mps.empty_cache()