Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions demos/BERT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Import stuff\n",
"import torch\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions demos/Main_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1012,9 +1012,7 @@
"Mathematically, centering is a linear map, normalizing is *not* a linear map, and scaling and translation are linear maps. \n",
"* **Centering:** LayerNorm is applied every time a layer reads from the residual stream, so the mean of any residual stream vector can never matter - `center_writing_weights` set every weight matrix writing to the residual to have zero mean. \n",
"* **Normalizing:** Normalizing is not a linear map, and cannot be factored out. The `hook_scale` hook point lets you access and control for this.\n",
"* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation. \n",
"\n",
"[See the docs for more details](https://github.com/TransformerLensOrg/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln)"
"* **Scaling and Translation:** Scaling and translation are linear maps, and are always followed by another linear map. The composition of two linear maps is another linear map, so we can *fold* the scaling and translation weights into the weights of the subsequent layer, and simplify things without changing the underlying computation. \n"
]
},
{
Expand Down
74 changes: 74 additions & 0 deletions tests/integration/model_bridge/test_glm4_moe_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Integration tests for the GLM-4.5 MoE TransformerBridge."""

from typing import Any

import pytest
import torch
from transformers import AutoModelForCausalLM

from transformer_lens.model_bridge.bridge import TransformerBridge
from transformer_lens.model_bridge.generalized_components import MoEBridge

MODEL_ID = "trl-internal-testing/tiny-Glm4MoeForCausalLM"


@pytest.fixture(scope="module")
def tiny_glm4_moe_bridge():
"""Load tiny GLM-4 MoE model via Hub."""

return TransformerBridge.boot_transformers(
MODEL_ID,
device="cpu",
dtype=torch.float32,
)


@pytest.fixture(scope="module")
def tiny_glm4_moe_hf() -> Any:
"""Load the raw HF model for parity checks."""
hf_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, torch_dtype=torch.float32, attn_implementation="eager"
).eval()

# Match the bridge's eager attention path exactly.
if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
hf_model.config._attn_implementation = "eager"
if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
for layer in hf_model.model.layers:
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
layer.self_attn.config._attn_implementation = "eager"

return hf_model


@pytest.fixture(scope="module")
def tiny_tokens():
return torch.tensor([[1, 2, 3, 4]])


class TestGlm4MoeBridgeStructure:
def test_blocks_have_moe_mlp(self, tiny_glm4_moe_bridge) -> None:
assert len(tiny_glm4_moe_bridge.blocks) > 0
for i, block in enumerate(tiny_glm4_moe_bridge.blocks):
assert isinstance(block.mlp, MoEBridge), f"blocks.{i}.mlp is not MoEBridge"

def test_required_top_level_fields(self, tiny_glm4_moe_bridge) -> None:
assert hasattr(tiny_glm4_moe_bridge, "embed")
assert hasattr(tiny_glm4_moe_bridge, "ln_final")
assert hasattr(tiny_glm4_moe_bridge, "unembed")
assert tiny_glm4_moe_bridge.cfg.final_rms is True
assert tiny_glm4_moe_bridge.cfg.normalization_type == "RMS"

def test_block_attn_has_q_norm_and_k_norm(self, tiny_glm4_moe_bridge) -> None:
block = tiny_glm4_moe_bridge.blocks[0]
assert hasattr(block.attn, "q_norm")
assert hasattr(block.attn, "k_norm")


def test_forward_matches_hf(tiny_glm4_moe_bridge, tiny_glm4_moe_hf: Any, tiny_tokens) -> None:
"""Bridge logits should match HuggingFace on the tiny checkpoint."""
with torch.no_grad():
bridge_out = tiny_glm4_moe_bridge(tiny_tokens)
hf_out = tiny_glm4_moe_hf(tiny_tokens).logits.float()
max_diff = (bridge_out - hf_out).abs().max().item()
assert max_diff < 1e-4
Loading
Loading