diff --git a/tests/integration/model_bridge/test_gemma4_bridge.py b/tests/integration/model_bridge/test_gemma4_bridge.py new file mode 100644 index 000000000..11e9a61a0 --- /dev/null +++ b/tests/integration/model_bridge/test_gemma4_bridge.py @@ -0,0 +1,80 @@ +"""Integration tests for the Gemma 4 TransformerBridge. + +Uses tiny random-init `Gemma4ForConditionalGeneration` fixtures (4 layers, d_model 8) so CI +stays light while still exercising the real per-layer heterogeneity across the family: + +- ``tiny-random/gemma-4-e`` — Per-Layer Embeddings + KV-cache sharing (E2B/E4B shape) +- ``tiny-random/gemma-4-dense`` — K==V attention on global layers, no v_proj (31B shape) +- ``tiny-random/gemma-4-moe`` — router + batched experts beside the dense MLP (26B-A4B shape) + +Confirms logit parity vs HF (the block bridge defers all math to HF) and that hooks fire on +the conventional single-stream residual. +""" + +import pytest +import torch + +from transformer_lens.model_bridge import TransformerBridge + +MODEL_NAMES = { + "ple_kv_shared": "tiny-random/gemma-4-e", + "dense_k_eq_v": "tiny-random/gemma-4-dense", + "moe": "tiny-random/gemma-4-moe", +} +IDS = torch.tensor([[1, 2, 3, 4, 5]]) + + +@pytest.fixture(scope="module", params=list(MODEL_NAMES), ids=list(MODEL_NAMES)) +def bridge(request): + return TransformerBridge.boot_transformers( + MODEL_NAMES[request.param], device="cpu", dtype=torch.float32 + ) + + +def test_text_only_logit_parity_vs_hf(bridge): + from transformers import AutoModelForCausalLM + + hf = AutoModelForCausalLM.from_pretrained( + bridge.cfg.model_name, torch_dtype=torch.float32, attn_implementation="eager" + ).eval() + with torch.no_grad(): + ref = hf(IDS).logits + out = bridge.forward(IDS, return_type="logits") + assert out.shape == ref.shape + # PLE / KV-sharing / K==V / MoE all run inside HF — the bridge is a pass-through. + assert torch.max(torch.abs(out - ref)).item() < 1e-3 + + +def test_config_from_text_config(bridge): + # Text dims resolve from the nested text_config of the multimodal model. + assert bridge.cfg.n_layers == 4 + assert getattr(bridge.cfg, "is_multimodal", False) is True + + +def test_resid_hooks_fire_with_conventional_shape(bridge): + """The residual stream is a single conventional (batch, seq, d_model) tensor.""" + captured = {} + + def cap(tensor, hook): + captured[hook.name] = tensor.detach() + return tensor + + names = [ + n + for n in bridge.hook_dict + if n.endswith("blocks.0.hook_resid_pre") or n.endswith("blocks.0.hook_resid_post") + ] + assert names, "no residual hooks registered" + with torch.no_grad(): + bridge.run_with_hooks(IDS, fwd_hooks=[(n, cap) for n in names]) + + assert captured, "residual hooks did not fire" + for tensor in captured.values(): + assert tensor.shape == (IDS.shape[0], IDS.shape[1], bridge.cfg.d_model) + + +def test_run_with_cache_text_only(bridge): + with torch.no_grad(): + logits, cache = bridge.run_with_cache(IDS) + assert torch.isfinite(logits).all() + assert len(cache) > 0 diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py new file mode 100644 index 000000000..afff05df4 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py @@ -0,0 +1,166 @@ +"""Unit tests for the Gemma 4 architecture adapter.""" + +from types import SimpleNamespace + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge.generalized_components import ( + DelegatedAttentionBlockBridge, + EmbeddingBridge, + LinearBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) +from transformer_lens.model_bridge.supported_architectures.gemma4 import ( + Gemma4ArchitectureAdapter, +) + +ARCH = "Gemma4ForConditionalGeneration" +ARCH_UNIFIED = "Gemma4UnifiedForConditionalGeneration" + + +def _cfg(arch: str = ARCH, **kwargs) -> TransformerBridgeConfig: + cfg = TransformerBridgeConfig( + d_model=1536, + d_head=256, + n_heads=8, + n_layers=35, + n_ctx=131072, + d_vocab=262144, + n_key_value_heads=1, + architecture=arch, + **kwargs, + ) + # Both variants are multimodal (have vision_config + embed_vision). + cfg.vision_config = SimpleNamespace( + hidden_size=2048, + num_hidden_layers=27, + num_attention_heads=16, + ) + cfg.vision_soft_tokens_per_image = 256 + return cfg + + +def _adapter(arch: str = ARCH) -> Gemma4ArchitectureAdapter: + return Gemma4ArchitectureAdapter(_cfg(arch)) + + +def test_config_flags(): + a = _adapter() + # Multimodal (Gemma4ForConditionalGeneration has vision tower + projector). + assert a.cfg.is_multimodal is True + # PLE / layer_scalar / MoE residual topology is not fold-safe. + assert a.supports_fold_ln is False + assert a.weight_processing_conversions == {} + assert a.cfg.normalization_type == "RMS" + # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3. + assert a.cfg.rmsnorm_uses_offset is False + assert a.cfg.positional_embedding_type == "rotary" + assert a.applicable_phases == [1, 2, 4] + + +def test_config_flags_unified(): + """Gemma4UnifiedForConditionalGeneration (12B) is encoder-free but still multimodal: + has model.embed_vision (raw-patch projector) but no model.vision_tower.""" + a = _adapter(ARCH_UNIFIED) + assert a.cfg.is_multimodal is True + assert "vision_encoder" not in a.component_mapping + assert "vision_projector" in a.component_mapping + assert a.component_mapping["vision_projector"].name == "model.embed_vision" + + +def test_text_path_nested_under_language_model(): + m = _adapter().component_mapping + assert m["embed"].name == "model.language_model.embed_tokens" + assert m["rotary_emb"].name == "model.language_model.rotary_emb" + assert m["blocks"].name == "model.language_model.layers" + assert m["ln_final"].name == "model.language_model.norm" + assert m["unembed"].name == "lm_head" + assert isinstance(m["embed"], EmbeddingBridge) + assert isinstance(m["rotary_emb"], RotaryEmbeddingBridge) + assert isinstance(m["blocks"], DelegatedAttentionBlockBridge) + assert isinstance(m["unembed"], UnembeddingBridge) + + +def test_vision_components_present_for_multimodal(): + """Gemma4ForConditionalGeneration has vision_tower + embed_vision.""" + m = _adapter().component_mapping + assert "vision_encoder" in m + assert "vision_projector" in m + assert m["vision_encoder"].name == "model.vision_tower" + assert m["vision_projector"].name == "model.embed_vision" + assert isinstance(m["vision_projector"], GeneralizedComponent) + # Vision config fields extracted from vision_config. + a = _adapter() + assert a.cfg.vision_hidden_size == 2048 + assert a.cfg.vision_num_layers == 27 + assert a.cfg.vision_num_heads == 16 + assert a.cfg.mm_tokens_per_image == 256 + + +def test_block_decomposition(): + blocks = _adapter().component_mapping["blocks"] + for name in ("attn", "mlp"): + assert name in blocks.submodules + # Sandwich norms (same shape as Gemma 2/3) under canonical keys. + for norm in ("ln1", "ln1_post", "ln2", "ln2_post"): + assert norm in blocks.submodules + assert blocks.submodules[norm].optional is False + + +def test_split_qkv_fork_aliases_absent(): + """Attention is delegated wholesale to HF; per-layer structure is heterogeneous + (KV-shared layers have no k/v projections), so the split-qkv fork aliases + do not apply.""" + blocks = _adapter().component_mapping["blocks"] + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): + assert alias not in blocks.hook_aliases + # The single-stream residual aliases remain, redirected through the sandwich norms. + assert blocks.hook_aliases["hook_resid_mid"] == "ln2.hook_in" + assert blocks.hook_aliases["hook_attn_out"] == "ln1_post.hook_out" + assert blocks.hook_aliases["hook_mlp_out"] == "ln2_post.hook_out" + + +def test_kv_shared_and_k_eq_v_submodules_are_optional(): + """KV-shared layers (E2B/E4B) drop k/v proj + norms; K==V global-attention + layers (31B / 26B-A4B) drop v_proj.""" + attn = _adapter().component_mapping["blocks"].submodules["attn"] + assert attn.submodules["q"].optional is False + assert attn.submodules["o"].optional is False + assert attn.submodules["q_norm"].optional is False + for shared in ("k", "v", "k_norm", "v_norm"): + assert attn.submodules[shared].optional is True + assert isinstance(attn.submodules["q"], LinearBridge) + + +def test_per_layer_embedding_submodules_are_optional(): + """PLE modules exist only when hidden_size_per_layer_input > 0 (E2B/E4B).""" + blocks = _adapter().component_mapping["blocks"] + for name in ( + "per_layer_input_gate", + "per_layer_projection", + "post_per_layer_input_norm", + ): + assert blocks.submodules[name].optional is True + + +def test_moe_submodules_are_optional(): + """MoE branch exists only when enable_moe_block (26B-A4B).""" + blocks = _adapter().component_mapping["blocks"] + for name in ( + "router", + "experts", + "pre_feedforward_layernorm_2", + "post_feedforward_layernorm_1", + "post_feedforward_layernorm_2", + ): + assert blocks.submodules[name].optional is True + + +def test_gated_mlp_decomposition(): + mlp = _adapter().component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["gate"].name == "gate_proj" + assert mlp.submodules["in"].name == "up_proj" + assert mlp.submodules["out"].name == "down_proj" diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py index 6e196ae4a..05735095f 100644 --- a/transformer_lens/benchmarks/main_benchmark.py +++ b/transformer_lens/benchmarks/main_benchmark.py @@ -941,7 +941,8 @@ def cleanup_model(model, model_name_str: str): model_name, trust_remote_code=trust_remote_code, token=_hf_token() ) if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: - hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) + eos = getattr(hf_config, "eos_token_id", None) + hf_config.pad_token_id = eos[0] if isinstance(eos, (list, tuple)) else eos hf_kwargs["config"] = hf_config if trust_remote_code: hf_kwargs["trust_remote_code"] = True diff --git a/transformer_lens/benchmarks/multimodal.py b/transformer_lens/benchmarks/multimodal.py index e18a374da..c6f4e163a 100644 --- a/transformer_lens/benchmarks/multimodal.py +++ b/transformer_lens/benchmarks/multimodal.py @@ -4,7 +4,6 @@ through forward(), generate(), and run_with_cache(). """ - import torch from transformer_lens.benchmarks.utils import ( @@ -46,8 +45,10 @@ def _prepare_test_inputs(bridge: TransformerBridge): # Different models use different tokens: # LLava: image_token = "" # Gemma3: boi_token = "" - image_token = getattr(bridge.processor, "boi_token", None) or getattr( - bridge.processor, "image_token", "" + # Gemma4: image_token is the expandable placeholder (280 tokens), + # boi_token ("<|image>") is just a marker — use image_token first. + image_token = getattr(bridge.processor, "image_token", None) or getattr( + bridge.processor, "boi_token", "" ) prompt = f"{image_token}\nDescribe this image." try: @@ -141,9 +142,9 @@ def benchmark_multimodal_forward( details={ "logits_shape": list(logits.shape), "input_ids_shape": list(input_ids.shape), - "pixel_values_shape": list(pixel_values.shape) - if pixel_values is not None - else None, + "pixel_values_shape": ( + list(pixel_values.shape) if pixel_values is not None else None + ), }, ) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 49dd134f7..50b402107 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -23,6 +23,7 @@ Gemma3ArchitectureAdapter, Gemma3MultimodalArchitectureAdapter, Gemma3nArchitectureAdapter, + Gemma4ArchitectureAdapter, GPT2ArchitectureAdapter, Gpt2LmHeadCustomArchitectureAdapter, GPTBigCodeArchitectureAdapter, @@ -86,6 +87,11 @@ "Gemma3ForCausalLM": Gemma3ArchitectureAdapter, "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter, "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter, + "Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter, + # The unified (encoder-free) 12B variant's text decoder is a strict structural + # subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the + # same module paths. Requires transformers >= 5.10 to load. + "Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter, "GraniteForCausalLM": GraniteArchitectureAdapter, "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter, "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter, diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 50d139f16..31d27dd5b 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -7,6 +7,7 @@ ) from transformer_lens.model_bridge.generalized_components.block import ( BlockBridge, + DelegatedAttentionBlockBridge, MLABlockBridge, ParallelBlockBridge, ) @@ -108,6 +109,7 @@ "AttentionBridge", "AudioFeatureExtractorBridge", "BlockBridge", + "DelegatedAttentionBlockBridge", "MLABlockBridge", "ParallelBlockBridge", "BloomBlockBridge", diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 506107781..37851b16a 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -405,3 +405,34 @@ def __init__( if self.hook_aliases is BlockBridge.hook_aliases: self.hook_aliases = dict(self.hook_aliases) self.hook_aliases.pop("hook_resid_mid", None) + + +class DelegatedAttentionBlockBridge(BlockBridge): + """Block whose attention is delegated wholesale to HF (no split-qkv fork). + + For architectures with heterogeneous per-layer attention structure — e.g. + Gemma 4, where KV-shared layers have no ``k_proj``/``v_proj`` at all and + K==V layers have no ``v_proj`` — there is no uniform HookPoint that + represents "input that becomes Q/K/V", so the block-level ``hook_q_input``/ + ``hook_k_input``/``hook_v_input``/``hook_attn_in`` aliases do not apply. + Type-level distinction means a reader of the adapter sees + ``DelegatedAttentionBlockBridge`` and knows those hooks are absent. + """ + + def __init__( + self, + name: str, + config: Optional[Any] = None, + submodules: Optional[Dict[str, GeneralizedComponent]] = None, + hook_alias_overrides: Optional[Dict[str, str]] = None, + ): + super().__init__( + name, + config=config, + submodules=submodules, + hook_alias_overrides=hook_alias_overrides, + ) + if self.hook_aliases is BlockBridge.hook_aliases: + self.hook_aliases = dict(self.hook_aliases) + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): + self.hook_aliases.pop(alias, None) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index a26e00004..3eafd3dc4 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -221,6 +221,10 @@ def determine_architecture_from_hf_config(hf_config): # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration # (vision/audio referenced but unbridged in the text-only adapter). "gemma3n": "Gemma3nForConditionalGeneration", + # gemma4 is multimodal-only; all released checkpoints load as the full + # ForConditionalGeneration (vision/audio referenced but unbridged). + "gemma4": "Gemma4ForConditionalGeneration", + "gemma4_unified": "Gemma4UnifiedForConditionalGeneration", "bert": "BertForMaskedLM", "bloom": "BloomForCausalLM", "codegen": "CodeGenForCausalLM", diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 84e4584af..acc8be08b 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -42,6 +42,9 @@ from transformer_lens.model_bridge.supported_architectures.gemma3n import ( Gemma3nArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.gemma4 import ( + Gemma4ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.gpt2 import ( GPT2ArchitectureAdapter, ) @@ -189,6 +192,7 @@ "Gemma3ArchitectureAdapter", "Gemma3nArchitectureAdapter", "Gemma3MultimodalArchitectureAdapter", + "Gemma4ArchitectureAdapter", "GraniteArchitectureAdapter", "GraniteMoeArchitectureAdapter", "GraniteMoeHybridArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/gemma4.py b/transformer_lens/model_bridge/supported_architectures/gemma4.py new file mode 100644 index 000000000..7c301f0dd --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/gemma4.py @@ -0,0 +1,170 @@ +"""Gemma 4 architecture adapter. + +Bridges the text path of ``Gemma4ForConditionalGeneration`` +(``model.language_model`` + ``lm_head``) and the vision pipeline. For the standard +variants (E2B / E4B / 31B / 26B-A4B) the vision encoder (``model.vision_tower``) and +projector (``model.embed_vision``) are both bridged, enabling Phase 7 multimodal testing. + +The same adapter also covers ``Gemma4UnifiedForConditionalGeneration`` (the +encoder-free 12B variant, transformers >= 5.10): its text decoder is a strict +structural subset — same module paths, no PLE and no MoE, both optional here. +It is still multimodal but has no ``vision_tower`` — ``model.embed_vision`` is the +full vision pipeline (raw-patch projection), mapped as the projector only. + +Per-layer structure is heterogeneous across the family, so all math is deferred to HF +and submodules are decomposed only for hooks (parity-safe delegation): + +- **KV sharing** (E2B/E4B): the last ``num_kv_shared_layers`` layers reuse earlier KV + states and drop their own ``k_proj`` / ``v_proj`` / ``k_norm`` / ``v_norm``. +- **K==V attention** (31B / 26B-A4B): global-attention layers share key and value + weights (``attention_k_eq_v``) and have no ``v_proj``. +- **Per-Layer Embeddings** (E2B/E4B): each layer mixes in a per-layer input via + ``per_layer_input_gate`` / ``per_layer_projection`` / ``post_per_layer_input_norm``. +- **MoE** (26B-A4B): layers add a ``router`` + batched ``experts`` block in parallel + with the dense MLP, sandwiched by three extra norms. + +Unlike Gemma 1-3, ``Gemma4RMSNorm`` multiplies by ``weight`` directly — there is no +``(1.0 + weight)`` offset. +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + DelegatedAttentionBlockBridge, + EmbeddingBridge, + LinearBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) + + +class Gemma4ArchitectureAdapter(ArchitectureAdapter): + """Adapter for Gemma 4 (`Gemma4ForConditionalGeneration` — multimodal, or + `Gemma4UnifiedForConditionalGeneration` — text-only 12B).""" + + # Phase 3 (processed/compatibility mode) folds LN into a single residual stream, + # which the PLE residual mix, per-layer `layer_scalar` buffers, and the MoE branch + # can't represent. Phases 1 (HF parity), 2 (hooks), and 4 (text quality) apply. + applicable_phases: list[int] = [1, 2, 4] + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + # Both variants are multimodal (take pixel_values). The difference: + # - Gemma4ForConditionalGeneration: vision_tower (encoder) + embed_vision (projector) + # - Gemma4UnifiedForConditionalGeneration (12B): embed_vision only — encoder-free + # embedder that does raw-patch projection without an attention-based vision encoder. + arch = getattr(cfg, "architecture", "") or "" + self._is_unified = "Gemma4Unified" in arch + self.cfg.is_multimodal = True + + if hasattr(cfg, "vision_config"): + vcfg = cfg.vision_config + self.cfg.vision_hidden_size = getattr(vcfg, "hidden_size", None) + self.cfg.vision_num_layers = getattr(vcfg, "num_hidden_layers", None) + self.cfg.vision_num_heads = getattr(vcfg, "num_attention_heads", None) + self.cfg.mm_tokens_per_image = getattr(cfg, "vision_soft_tokens_per_image", 256) + + self.cfg.gated_mlp = True + self.cfg.uses_rms_norm = True + self.cfg.normalization_type = "RMS" + # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3. + self.cfg.rmsnorm_uses_offset = False + self.cfg.positional_embedding_type = "rotary" + self.cfg.attn_implementation = "eager" + # PLE / layer_scalar / MoE residual topology isn't fold-safe. + self.supports_fold_ln = False + self.weight_processing_conversions: dict = {} + + # Vision components. Gemma4ForConditionalGeneration has a separate vision + # encoder (model.vision_tower) + projector (model.embed_vision). The 12B + # unified variant is encoder-free — model.embed_vision is the full vision + # pipeline (raw-patch projection), so it maps as the projector with no encoder. + _vision_mapping: dict[str, Any] = { + "vision_projector": GeneralizedComponent(name="model.embed_vision"), + } + if not self._is_unified: + _vision_mapping = { + "vision_encoder": GeneralizedComponent(name="model.vision_tower"), + **_vision_mapping, + } + + self.component_mapping = { + **_vision_mapping, + "embed": EmbeddingBridge(name="model.language_model.embed_tokens"), + # Single rotary module serving both layer types (full / sliding) via a + # per-layer-type forward kwarg, with separate rope parameters per type. + "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"), + "blocks": DelegatedAttentionBlockBridge( + name="model.language_model.layers", + submodules={ + # Sandwich norms: ln1/ln1_post around attention, ln2/ln2_post + # around the MLP (same shape as Gemma 2/3). + "ln1": GeneralizedComponent(name="input_layernorm"), + "ln1_post": GeneralizedComponent(name="post_attention_layernorm"), + "ln2": GeneralizedComponent(name="pre_feedforward_layernorm"), + "ln2_post": GeneralizedComponent(name="post_feedforward_layernorm"), + # PLE residual mix — present only when hidden_size_per_layer_input > 0 + # (E2B/E4B; absent on 31B and 26B-A4B). + "per_layer_input_gate": GeneralizedComponent( + name="per_layer_input_gate", optional=True + ), + "per_layer_projection": GeneralizedComponent( + name="per_layer_projection", optional=True + ), + "post_per_layer_input_norm": GeneralizedComponent( + name="post_per_layer_input_norm", optional=True + ), + # MoE branch — present only when enable_moe_block (26B-A4B). + "router": GeneralizedComponent(name="router", optional=True), + "experts": GeneralizedComponent(name="experts", optional=True), + "pre_feedforward_layernorm_2": GeneralizedComponent( + name="pre_feedforward_layernorm_2", optional=True + ), + "post_feedforward_layernorm_1": GeneralizedComponent( + name="post_feedforward_layernorm_1", optional=True + ), + "post_feedforward_layernorm_2": GeneralizedComponent( + name="post_feedforward_layernorm_2", optional=True + ), + "attn": GeneralizedComponent( + name="self_attn", + submodules={ + "q": LinearBridge(name="q_proj"), + # KV-shared layers (E2B/E4B) drop k/v projections and norms; + # K==V layers (31B / 26B-A4B global attention) drop v_proj. + "k": LinearBridge(name="k_proj", optional=True), + "v": LinearBridge(name="v_proj", optional=True), + "o": LinearBridge(name="o_proj"), + "q_norm": GeneralizedComponent(name="q_norm"), + "k_norm": GeneralizedComponent(name="k_norm", optional=True), + "v_norm": GeneralizedComponent(name="v_norm", optional=True), + }, + ), + "mlp": GeneralizedComponent( + name="mlp", + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": GeneralizedComponent(name="model.language_model.norm"), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Force eager attention so bridge and HF match (sliding/full layer mix).""" + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + language_model = getattr(getattr(hf_model, "model", None), "language_model", None) + if language_model is not None and hasattr(language_model, "layers"): + for layer in language_model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index 769b9b0d1..65d24332e 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -59,6 +59,8 @@ "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", + "Gemma4UnifiedForConditionalGeneration", "GraniteForCausalLM", "GraniteMoeForCausalLM", "GraniteMoeHybridForCausalLM", @@ -118,6 +120,8 @@ "Gemma3ForCausalLM": ["google"], "Gemma3ForConditionalGeneration": ["google"], "Gemma3nForConditionalGeneration": ["google"], + "Gemma4ForConditionalGeneration": ["google"], + "Gemma4UnifiedForConditionalGeneration": ["google"], "GemmaForCausalLM": ["google"], "GPT2LMHeadModel": ["openai-community", "stanford-crfm", "Writer"], "GPTBigCodeForCausalLM": ["bigcode"], diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index d668e14e8..d0380cfa4 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 8.0 }, - "total_architectures": 55, - "total_models": 12112, - "total_verified": 743, + "total_architectures": 57, + "total_models": 12122, + "total_verified": 744, "models": [ { "architecture_id": "MistralForCausalLM", @@ -168137,6 +168137,146 @@ "phase4_score": null, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E2B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E2B-it", + "status": 1, + "verified_date": "2026-06-22", + "metadata": null, + "note": "Full verification completed", + "phase1_score": 50.0, + "phase2_score": null, + "phase3_score": null, + "phase4_score": 98.7, + "phase7_score": 100.0, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E4B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E4B-it", + "status": 1, + "verified_date": "2026-06-22", + "metadata": null, + "note": null, + "phase1_score": 50.0, + "phase2_score": null, + "phase3_score": null, + "phase4_score": 98.5, + "phase7_score": 100.0, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-31B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-31B-it", + "status": 1, + "verified_date": "2026-06-22", + "metadata": null, + "note": "Phase 1 requires --no-hf-reference (>80GB VRAM for 2 copies). Phase 4 text quality affected by repetition. Phase 7 multimodal requires bfloat16 dtype (fp16 overflows in vision tower).", + "phase1_score": 100.0, + "phase2_score": null, + "phase3_score": null, + "phase4_score": 55.1, + "phase7_score": 100.0, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-26B-A4B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-26B-A4B-it", + "status": 1, + "verified_date": "2026-06-22", + "metadata": null, + "note": "Phase 1 requires --no-hf-reference (>80GB VRAM for 2 copies). Phase 4 text quality affected by MoE repetition. Phase 7 multimodal passes.", + "phase1_score": 100.0, + "phase2_score": null, + "phase3_score": null, + "phase4_score": 69.4, + "phase7_score": 100.0, + "phase8_score": null + }, + { + "architecture_id": "Gemma4UnifiedForConditionalGeneration", + "model_id": "google/gemma-4-12B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4UnifiedForConditionalGeneration", + "model_id": "google/gemma-4-12B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index f96594e31..3a733193d 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-06-05T13:10:29.591019", + "last_updated": "2026-06-22T10:47:55.336299", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -12360,6 +12360,36 @@ "notes": "Full verification completed", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-E2B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-10", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-31B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-22", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAU", + "invalidated": false, + "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-26B-A4B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-22", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAU", + "invalidated": false, + "invalidation_reason": null } ] } diff --git a/transformer_lens/tools/model_registry/generate_report.py b/transformer_lens/tools/model_registry/generate_report.py index 9349844a2..e3ead2444 100644 --- a/transformer_lens/tools/model_registry/generate_report.py +++ b/transformer_lens/tools/model_registry/generate_report.py @@ -37,6 +37,8 @@ "Gemma2ForCausalLM": "Google's Gemma 2 with improved architecture", "Gemma3ForCausalLM": "Google's Gemma 3 latest generation", "Gemma3nForConditionalGeneration": "Google's Gemma 3n efficient tri-modal model (text-only support)", + "Gemma4ForConditionalGeneration": "Google's Gemma 4 multimodal model family (text-only support)", + "Gemma4UnifiedForConditionalGeneration": "Google's Gemma 4 unified encoder-free multimodal model (text-only support)", "Qwen2ForCausalLM": "Alibaba's Qwen2 multilingual model", "Qwen3ForCausalLM": "Alibaba's Qwen3 latest generation", "Qwen3_5ForConditionalGeneration": "Alibaba's Qwen3.5 vision-language model", diff --git a/transformer_lens/tools/model_registry/verify_models.py b/transformer_lens/tools/model_registry/verify_models.py index fc75f07b0..a2a7fbdad 100644 --- a/transformer_lens/tools/model_registry/verify_models.py +++ b/transformer_lens/tools/model_registry/verify_models.py @@ -1116,7 +1116,7 @@ def verify_models( if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): + if torch.backends.mps.is_available(): torch.mps.synchronize() torch.mps.empty_cache() diff --git a/transformer_lens/utilities/architectures.py b/transformer_lens/utilities/architectures.py index 89440956f..3a7db5db6 100644 --- a/transformer_lens/utilities/architectures.py +++ b/transformer_lens/utilities/architectures.py @@ -33,6 +33,7 @@ "LlavaNextForConditionalGeneration", "LlavaOnevisionForConditionalGeneration", "Gemma3ForConditionalGeneration", + "Gemma4ForConditionalGeneration", "Qwen3_5ForConditionalGeneration", }