Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions tests/unit/utilities/test_typed_module_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Unit tests for TypedModuleList.

TypedModuleList is a generic, type-preserving wrapper around nn.ModuleList. These
tests pin down its runtime behaviour: it must stay a drop-in nn.ModuleList (same
iteration, indexing, slicing, and submodule registration). The static typing
guarantee (iteration/indexing yield the element type) is exercised separately by
mypy over the rest of the package.
"""

import pytest
import torch.nn as nn

from transformer_lens.utilities import TypedModuleList
from transformer_lens.utilities.typed_module_list import (
TypedModuleList as TypedModuleListDirect,
)


def _linears(n: int) -> list[nn.Linear]:
    """Return ``n`` freshly constructed ``nn.Linear(3, 3)`` layers.

    Every call builds new layer objects, so the returned list can be used to
    check object identity after the layers pass through a container.
    """
    return [nn.Linear(3, 3) for _ in range(n)]


class TestTypedModuleList:
    """Runtime behaviour of TypedModuleList.

    Each test pins one aspect in which TypedModuleList must keep behaving like
    nn.ModuleList: construction, iteration, indexing, slicing, mutation, and
    submodule registration.
    """

    def test_exported_from_utilities_package(self) -> None:
        """It is re-exported from transformer_lens.utilities."""
        assert TypedModuleList is TypedModuleListDirect

    def test_is_an_nn_module_list(self) -> None:
        """It must remain a genuine nn.ModuleList subclass."""
        assert isinstance(TypedModuleList(_linears(2)), nn.ModuleList)

    @pytest.mark.parametrize("n", [0, 1, 3])
    def test_construction_and_len(self, n: int) -> None:
        """len() matches the number of modules passed in (including empty)."""
        assert len(TypedModuleList(_linears(n))) == n

    def test_construction_with_no_arguments(self) -> None:
        """It can be constructed empty, like nn.ModuleList()."""
        empty = TypedModuleList()
        assert len(empty) == 0
        assert list(empty) == []

    def test_iteration_preserves_order_and_identity(self) -> None:
        """Iterating yields the exact module objects, in order."""
        layers = _linears(3)
        iterated = list(TypedModuleList(layers))
        assert len(iterated) == len(layers)
        # Compare with `is`: the contract is identity, not mere equality.
        assert all(got is want for got, want in zip(iterated, layers))

    def test_integer_indexing_returns_the_element(self) -> None:
        """Integer indexing (including negative) returns the stored module."""
        layers = _linears(3)
        tml = TypedModuleList(layers)
        assert tml[0] is layers[0]
        assert tml[-1] is layers[-1]

    def test_slicing_returns_a_typed_module_list(self) -> None:
        """Slicing returns a TypedModuleList (not a bare nn.ModuleList) with the right modules."""
        layers = _linears(4)
        expected = layers[1:3]
        sliced = TypedModuleList(layers)[1:3]
        assert isinstance(sliced, TypedModuleList)
        assert len(sliced) == len(expected)
        # The slice must hold the very same module objects, in the same order.
        assert all(got is want for got, want in zip(sliced, expected))

    def test_modules_are_registered_as_submodules(self) -> None:
        """Child modules are registered, so parameters() / named_children() work."""
        tml = TypedModuleList(_linears(2))
        # Two Linear(3, 3) layers => 2 * (weight + bias) == 4 parameter tensors.
        assert len(list(tml.parameters())) == 4
        assert [name for name, _ in tml.named_children()] == ["0", "1"]

    def test_registers_correctly_when_nested_in_a_module(self) -> None:
        """Used as a submodule, its parameters surface on the parent under the attribute name."""

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.blocks = TypedModuleList(_linears(2))

        net = Net()
        param_names = [name for name, _ in net.named_parameters()]
        assert len(param_names) == 4
        assert all(name.startswith("blocks.") for name in param_names)

    def test_append_mutates_and_returns_self(self) -> None:
        """append adds the module and returns the same list (for chaining), like nn.ModuleList."""
        tml: TypedModuleList[nn.Linear] = TypedModuleList()
        layer = nn.Linear(3, 3)
        returned = tml.append(layer)
        assert returned is tml
        assert isinstance(returned, TypedModuleList)
        assert len(tml) == 1
        assert tml[0] is layer

    def test_setitem_replaces_element(self) -> None:
        """__setitem__ replaces the module at an existing index and re-registers it."""
        layers = _linears(2)
        tml = TypedModuleList(layers)
        replacement = nn.Linear(3, 3)
        tml[0] = replacement
        assert tml[0] is replacement
        assert tml[1] is layers[1]
        assert len(tml) == 2
        # The replacement must be registered (old child gone, new child present).
        # children() is a generator, so materialise it once and check by identity.
        children = list(tml.children())
        assert any(child is replacement for child in children)
        assert all(child is not layers[0] for child in children)
10 changes: 3 additions & 7 deletions transformer_lens/ActivationCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class first, including the examples, and then skimming the available methods. Yo
from transformer_lens.utilities import Slice, SliceInput, warn_if_mps

if TYPE_CHECKING:
from transformer_lens.components import TransformerBlock
from transformer_lens.HookedTransformer import HookedTransformer


Expand Down Expand Up @@ -748,8 +747,7 @@ def compute_head_results(
)

# Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model])
# nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T.
block = cast("TransformerBlock", self.model.blocks[layer])
block = self.model.blocks[layer]
result = z * block.attn.W_O

# Sum over d_head to get the contribution of each head to the residual stream
Expand Down Expand Up @@ -906,8 +904,7 @@ def get_neuron_results(
pos_slice = Slice(pos_slice)

neuron_acts = self[("post", layer, "mlp")]
# ModuleList[T] indexing is typed `Tensor | Module` upstream; cast restores T.
block = cast("TransformerBlock", self.model.blocks[layer])
block = self.model.blocks[layer]
W_out = block.mlp.W_out
if pos_slice is not None:
# Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
Expand Down Expand Up @@ -974,8 +971,7 @@ def _stack_neuron_results_apply_ln_projected(

components: list = []
for l in range(layer):
# nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T.
block = cast("TransformerBlock", self.model.blocks[l])
block = self.model.blocks[l]
W_out_l = block.mlp.W_out # [d_mlp, d_model]
W_out_l_sliced = neuron_slice.apply(W_out_l, dim=0)
W_proj_l = W_out_l_sliced @ project_2d # [d_mlp, n_outs]
Expand Down
38 changes: 7 additions & 31 deletions transformer_lens/HookedAudioEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload

import numpy as np
import torch
import torch.nn as nn
from einops import repeat
from jaxtyping import Float, Int
from transformers import AutoFeatureExtractor, HubertModel, Wav2Vec2Model
from typing_extensions import Literal

from transformer_lens import loading_from_pretrained as loading
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import MLP, Attention, BertBlock
from transformer_lens.components import MLP, BertBlock
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.HookedRootModule import HookedRootModule
from transformer_lens.utilities import devices
from transformer_lens.utilities import TypedModuleList, devices

T = TypeVar("T", bound="HookedAudioEncoder")

Expand All @@ -41,6 +40,7 @@ class HookedAudioEncoder(HookedRootModule):

processor: Any # AutoFeatureExtractor — HF auto class, not typed as callable in stubs
hubert_model: Union[HubertModel, Wav2Vec2Model]
blocks: TypedModuleList[BertBlock]

def __init__(
self,
Expand All @@ -60,7 +60,7 @@ def __init__(

assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"

self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
self.blocks = TypedModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])

if move_to_device:
if self.cfg.device is None:
Expand Down Expand Up @@ -426,86 +426,62 @@ def from_pretrained(
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the key weights across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_K for block in self.blocks], dim=0)

@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the query weights across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)

@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the value weights across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_V for block in self.blocks], dim=0)

@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
"""Stacks the attn output weights across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_O for block in self.blocks], dim=0)

@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
"""Stacks the MLP input weights across all layers"""
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)

@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
"""Stacks the MLP output weights across all layers"""
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)

@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the key biases across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_K for block in self.blocks], dim=0)

@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the query biases across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)

@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the value biases across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_V for block in self.blocks], dim=0)

@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the attn output biases across all layers"""
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_O for block in self.blocks], dim=0)

@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
"""Stacks the MLP input biases across all layers"""
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)
return torch.stack([cast(MLP, block.mlp).b_in for block in self.blocks], dim=0)

@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the MLP output biases across all layers"""
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)
return torch.stack([cast(MLP, block.mlp).b_out for block in self.blocks], dim=0)

@property
def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
Expand Down
35 changes: 15 additions & 20 deletions transformer_lens/HookedEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload

import torch
import torch.nn as nn
from einops import repeat
from jaxtyping import Float, Int
from transformers.models.auto.tokenization_auto import AutoTokenizer
Expand All @@ -33,7 +32,7 @@
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedRootModule import HookedRootModule
from transformer_lens.utilities import devices
from transformer_lens.utilities import TypedModuleList, devices

T = TypeVar("T", bound="HookedEncoder")

Expand All @@ -49,11 +48,7 @@ class HookedEncoder(HookedRootModule):
- There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
"""

blocks: nn.ModuleList[BertBlock] # type: ignore[type-arg]

def _get_blocks(self) -> list[BertBlock]:
"""Helper to get blocks with proper typing."""
return [cast(BertBlock, block) for block in self.blocks]
blocks: TypedModuleList[BertBlock]

def __init__(
self,
Expand Down Expand Up @@ -91,7 +86,7 @@ def __init__(
self.cfg.d_vocab_out = self.cfg.d_vocab

self.embed = BertEmbed(self.cfg)
self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
self.blocks = TypedModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
self.mlm_head = BertMLMHead(self.cfg)
self.unembed = Unembed(self.cfg)
self.nsp_head = BertNSPHead(self.cfg)
Expand Down Expand Up @@ -471,69 +466,69 @@ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the key weights across all layers"""
return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.W_K for block in self.blocks], dim=0)

@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the query weights across all layers"""
return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)

@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the value weights across all layers"""
return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.W_V for block in self.blocks], dim=0)

@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
"""Stacks the attn output weights across all layers"""
return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.W_O for block in self.blocks], dim=0)

@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
"""Stacks the MLP input weights across all layers"""
return torch.stack(
[cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0
[cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self.blocks], dim=0
)

@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
"""Stacks the MLP output weights across all layers"""
return torch.stack(
[cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0
[cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self.blocks], dim=0
)

@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the key biases across all layers"""
return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.b_K for block in self.blocks], dim=0)

@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the query biases across all layers"""
return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)

@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the value biases across all layers"""
return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.b_V for block in self.blocks], dim=0)

@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the attn output biases across all layers"""
return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0)
return torch.stack([block.attn.b_O for block in self.blocks], dim=0)

@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
"""Stacks the MLP input biases across all layers"""
return torch.stack(
[cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0
[cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self.blocks], dim=0
)

@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the MLP output biases across all layers"""
return torch.stack(
[cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0
[cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self.blocks], dim=0
)

@property
Expand Down
Loading
Loading