From bddafb6619ece49a181a5f7a393f0e1abbe8861f Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Wed, 17 Jun 2026 20:15:03 +0000 Subject: [PATCH 1/9] feat: add get_act_patch_direct_path for head-to-head circuit analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #111. Implements direct path patching — a finer-grained variant of activation patching that isolates the direct information flow between two specific attention heads, rather than replacing the full residual stream. Why --- Standard activation patching tells you that *some* component at layer L matters, but it cannot distinguish whether head B at layer L+2 matters because it received information directly from head A, or because A's output propagated through many intermediate components first. Direct path patching isolates the A → B causal edge precisely. Implementation -------------- For a fixed source head A = (src_layer, src_head) and every downstream destination head B = (dst_layer, dst_head): delta_resid = clean_A_result - corrupted_A_result # [batch, pos, d_model] delta_B_q = (delta_resid / ln1_scale) @ W_Q[hb] # [batch, pos, d_head] patched_B_q = corrupted_B_q + delta_B_q The per-head residual contribution is computed from hook_z @ W_O (always available in the default cache) rather than hook_result, which requires the non-default cfg.use_hook_result=True flag. New files --------- - transformer_lens/direct_path_patching.py get_act_patch_direct_path() [n_layers, n_heads] sweep get_act_patch_direct_path_all_sources() [n_layers, n_heads, n_layers, n_heads] full sweep - tests/unit/test_direct_path_patching.py 12 tests covering output shape, causal structure, manual correctness verification, and edge cases. All pass on a tiny randomly-initialised 3-layer model (no downloads, runs in ~3s on CPU). - demos/direct_path_patching_ioi.py Validated on GPT-2 small / IOI task. S-inhibition heads (7.3, 7.9, 8.6, 8.10) show strongest direct paths into name-mover heads (9.9, 9.6, 10.0), confirming the Wang et al. 2022 IOI circuit. (8,6) → (9,9): +0.083 normalised logit diff (8,10) → (9,9): +0.066 (7,9) → (9,9): +0.036 API matches existing get_act_patch_* functions in patching.py for drop-in use alongside the existing circuit analysis toolkit. --- demos/direct_path_patching_ioi.py | 142 ++++++++++ tests/unit/test_direct_path_patching.py | 337 +++++++++++++++++++++++ transformer_lens/__init__.py | 2 + transformer_lens/direct_path_patching.py | 289 +++++++++++++++++++ 4 files changed, 770 insertions(+) create mode 100644 demos/direct_path_patching_ioi.py create mode 100644 tests/unit/test_direct_path_patching.py create mode 100644 transformer_lens/direct_path_patching.py diff --git a/demos/direct_path_patching_ioi.py b/demos/direct_path_patching_ioi.py new file mode 100644 index 000000000..11db934ca --- /dev/null +++ b/demos/direct_path_patching_ioi.py @@ -0,0 +1,142 @@ +""" +Direct Path Patching — Real Experiment on GPT-2 Small (IOI task) +================================================================= + +Indirect Object Identification (IOI) task from Wang et al. 2022: + Clean: "When Mary and John went to the store, John gave a drink to" → Mary + Corrupted: "When Mary and John went to the store, Mary gave a drink to" → John + +We measure logit(Mary) - logit(John) as our metric. + +For each source head known to be important in the IOI circuit +(S-inhibition heads: 8.6, 8.10, 7.3, 7.9), we patch its output +directly into the queries of all downstream heads and see which +(src → dst) paths carry the most information. +""" + +import sys, os, importlib.util +import torch +import einops +from transformer_lens import HookedTransformer + +# Load our local module +_path = os.path.join(os.path.dirname(__file__), "..", "transformer_lens", "direct_path_patching.py") +_spec = importlib.util.spec_from_file_location("direct_path_patching", _path) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +get_act_patch_direct_path = _mod.get_act_patch_direct_path + +# --------------------------------------------------------------------------- +# 1. Load model +# --------------------------------------------------------------------------- +print("Loading GPT-2 small…") +model = HookedTransformer.from_pretrained("gpt2", center_unembed=True, center_writing_weights=True, fold_ln=True) +model.eval() + +# --------------------------------------------------------------------------- +# 2. Define IOI prompts +# --------------------------------------------------------------------------- +# Classic IOI pair from Wang et al. 2022 +CLEAN_PROMPT = "When Mary and John went to the store, John gave a drink to" +CORRUPTED_PROMPT = "When Mary and John went to the store, Mary gave a drink to" + +clean_tokens = model.to_tokens(CLEAN_PROMPT) # [1, seq] +corrupted_tokens = model.to_tokens(CORRUPTED_PROMPT) # [1, seq] + +# Token IDs for Mary and John +mary_token = model.to_single_token(" Mary") +john_token = model.to_single_token(" John") +print(f"Mary token id: {mary_token}, John token id: {john_token}") + +# --------------------------------------------------------------------------- +# 3. Metric: logit(Mary) - logit(John) at the last token position +# --------------------------------------------------------------------------- +def logit_diff(logits): + """Higher = more correct (predicts Mary over John).""" + last = logits[0, -1, :] + return last[mary_token] - last[john_token] + +# Baselines +with torch.no_grad(): + clean_logits, clean_cache = model.run_with_cache(clean_tokens) + corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens) + +clean_baseline = logit_diff(clean_logits).item() +corrupted_baseline = logit_diff(corrupted_logits).item() +print(f"\nClean logit diff: {clean_baseline:+.3f} (correctly prefers Mary)") +print(f"Corrupted logit diff: {corrupted_baseline:+.3f} (incorrectly prefers John)") + +# Normalised metric: 0 = corrupted baseline, 1 = clean baseline +def normalised_metric(logits): + raw = logit_diff(logits) + return (raw - corrupted_baseline) / (clean_baseline - corrupted_baseline) + +# --------------------------------------------------------------------------- +# 4. Run direct path patching for known IOI circuit heads +# --------------------------------------------------------------------------- +# S-inhibition heads that write to the Name Mover heads' queries +ioi_src_heads = [ + (7, 3), # S-inhibition head + (7, 9), # S-inhibition head + (8, 6), # S-inhibition head + (8, 10), # S-inhibition head +] + +print("\n" + "="*60) +print("DIRECT PATH PATCHING RESULTS") +print("Metric: normalised logit diff (0=corrupted, 1=clean)") +print("="*60) + +all_results = {} +for (sl, sh) in ioi_src_heads: + print(f"\nSource head ({sl},{sh}) → all downstream heads [Q input]") + with torch.no_grad(): + results = get_act_patch_direct_path( + model=model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=normalised_metric, + src_layer=sl, + src_head=sh, + component="q", + verbose=True, + ) + all_results[(sl, sh)] = results + + # Top 5 (dst_layer, dst_head) destinations + flat = results.view(-1) + top5_idx = flat.topk(5).indices + print(f" Top 5 destinations (by normalised metric):") + for idx in top5_idx: + dl = idx.item() // model.cfg.n_heads + dh = idx.item() % model.cfg.n_heads + val = results[dl, dh].item() + print(f" ({sl},{sh}) → ({dl},{dh}): {val:+.4f}") + +# --------------------------------------------------------------------------- +# 5. Summary: Known name-mover head query inputs +# --------------------------------------------------------------------------- +name_movers = [(9, 9), (9, 6), (10, 0)] # confirmed in IOI paper + +print("\n" + "="*60) +print("DIRECT PATH: S-inhibition → Name-Mover (Q) scores") +print("Expected: strong signal for known circuit edges") +print("="*60) +print(f"{'Src head':>10} {'Dst head':>10} {'Score':>8}") +print("-"*34) +for (sl, sh), results in all_results.items(): + for (dl, dh) in name_movers: + if dl > sl: + val = results[dl, dh].item() + print(f" ({sl:2d},{sh:2d}) → ({dl:2d},{dh:2d}) {val:+.4f}") + +# --------------------------------------------------------------------------- +# 6. Save results +# --------------------------------------------------------------------------- +torch.save( + {k: v for k, v in all_results.items()}, + os.path.join(os.path.dirname(__file__), "results_direct_path_ioi.pt"), +) +print("\nResults saved to demos/results_direct_path_ioi.pt") +print("\nDone.") diff --git a/tests/unit/test_direct_path_patching.py b/tests/unit/test_direct_path_patching.py new file mode 100644 index 000000000..3f7348da3 --- /dev/null +++ b/tests/unit/test_direct_path_patching.py @@ -0,0 +1,337 @@ +"""Tests for direct_path_patching.py + +Run with: + pytest tests/test_direct_path_patching.py -v + +These tests use a tiny randomly-initialised 2-layer GPT-2 config so they run +in seconds on CPU without downloading any weights. +""" + +import pytest +import torch +from transformer_lens import HookedTransformer, HookedTransformerConfig + +from transformer_lens.direct_path_patching import ( + get_act_patch_direct_path, + get_act_patch_direct_path_all_sources, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def tiny_model(): + """A small, randomly-initialised transformer for fast tests.""" + cfg = HookedTransformerConfig( + n_layers=3, + d_model=64, + d_head=16, + n_heads=4, + d_mlp=128, + d_vocab=100, + n_ctx=16, + act_fn="gelu", + normalization_type="LN", + attn_only=False, + ) + model = HookedTransformer(cfg) + model.eval() + return model + + +@pytest.fixture(scope="module") +def tokens_and_caches(tiny_model): + """Precompute clean/corrupted tokens and their caches.""" + torch.manual_seed(42) + clean_tokens = torch.randint(0, 100, (1, 8)) + corrupted_tokens = torch.randint(0, 100, (1, 8)) + + with torch.no_grad(): + _, clean_cache = tiny_model.run_with_cache(clean_tokens) + _, corrupted_cache = tiny_model.run_with_cache(corrupted_tokens) + + return clean_tokens, corrupted_tokens, clean_cache, corrupted_cache + + +def simple_metric(logits): + """Sum of last-token logits — a trivially differentiable scalar.""" + return logits[0, -1, :].sum() + + +# --------------------------------------------------------------------------- +# Shape tests +# --------------------------------------------------------------------------- + + +class TestOutputShape: + def test_single_source_shape(self, tiny_model, tokens_and_caches): + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=0, + src_head=0, + component="q", + verbose=False, + ) + assert results.shape == (tiny_model.cfg.n_layers, tiny_model.cfg.n_heads), ( + f"Expected ({tiny_model.cfg.n_layers}, {tiny_model.cfg.n_heads}), got {results.shape}" + ) + + def test_all_sources_shape(self, tiny_model, tokens_and_caches): + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + with torch.no_grad(): + results = get_act_patch_direct_path_all_sources( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + component="q", + verbose=False, + ) + n = tiny_model.cfg.n_layers + h = tiny_model.cfg.n_heads + assert results.shape == (n, h, n, h), ( + f"Expected ({n},{h},{n},{h}), got {results.shape}" + ) + + @pytest.mark.parametrize("component", ["q", "k", "v"]) + def test_all_components(self, tiny_model, tokens_and_caches, component): + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=0, + src_head=1, + component=component, + verbose=False, + ) + assert results.shape == (tiny_model.cfg.n_layers, tiny_model.cfg.n_heads) + + +# --------------------------------------------------------------------------- +# Causal structure tests +# --------------------------------------------------------------------------- + + +class TestCausalStructure: + def test_earlier_layers_are_zero(self, tiny_model, tokens_and_caches): + """Entries for dst_layer <= src_layer must be exactly 0 (no causal path).""" + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + src_layer = 1 + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=src_layer, + src_head=0, + component="q", + verbose=False, + ) + # Rows 0..src_layer (inclusive) should be exactly 0 + assert results[: src_layer + 1].eq(0).all(), ( + "Expected zero for dst_layer <= src_layer, but got non-zero entries." + ) + + def test_later_layers_are_nonzero(self, tiny_model, tokens_and_caches): + """At least some entries for dst_layer > src_layer should be non-zero + when clean != corrupted.""" + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + src_layer = 0 + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=src_layer, + src_head=0, + component="q", + verbose=False, + ) + downstream = results[src_layer + 1 :] + assert not downstream.eq(0).all(), ( + "Expected at least some non-zero values for downstream layers, " + "but all were zero (this is extremely unlikely with random weights)." + ) + + def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): + """If clean == corrupted, the delta is zero and the metric should be + identical for all pairs (patching in nothing should change nothing).""" + torch.manual_seed(7) + tokens = torch.randint(0, 100, (1, 6)) + with torch.no_grad(): + baseline_logits, cache = tiny_model.run_with_cache(tokens) + + baseline = simple_metric(baseline_logits).item() + + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=tokens, + clean_cache=cache, + corrupted_cache=cache, # same cache → delta = 0 + patching_metric=simple_metric, + src_layer=0, + src_head=0, + component="q", + verbose=False, + ) + + # Every entry should equal the baseline (delta is zero, so hooks do nothing) + nonzero_entries = results[results != 0] + assert nonzero_entries.numel() == 0 or torch.allclose( + nonzero_entries, + torch.full_like(nonzero_entries, baseline), + atol=1e-4, + ), "When clean==corrupted, metric should equal baseline for all pairs." + + +# --------------------------------------------------------------------------- +# Correctness: manual verification for a single pair +# --------------------------------------------------------------------------- + + +class TestCorrectness: + def test_manual_patch_matches_function(self, tiny_model, tokens_and_caches): + """Manually apply the linear approximation for one (src→dst) pair and + compare with the function's output.""" + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + src_layer, src_head = 0, 0 + dst_layer, dst_head = 2, 1 + component = "q" + + # --- Manually compute the expected patched metric --- + W_O = tiny_model.blocks[src_layer].attn.W_O # [n_heads, d_head, d_model] + def _head_result(cache, h): + z = cache[f"blocks.{src_layer}.attn.hook_z"][:, :, h, :] + return z @ W_O[h] + delta_resid = _head_result(clean_cache, src_head) - _head_result(corrupted_cache, src_head) + ln_scale = corrupted_cache[f"blocks.{dst_layer}.ln1.hook_scale"] + W_Q_dst = tiny_model.blocks[dst_layer].attn.W_Q[dst_head] # [d_model, d_head] + delta_q = (delta_resid / ln_scale) @ W_Q_dst + + def manual_hook(value, hook): + if value.requires_grad: + value = value.clone() + value[:, :, dst_head, :] = value[:, :, dst_head, :] + delta_q + return value + + with torch.no_grad(): + patched_logits = tiny_model.run_with_hooks( + corrupted_tokens, + fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_q", manual_hook)], + ) + expected = simple_metric(patched_logits).item() + + # --- Get the function's result --- + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=src_layer, + src_head=src_head, + component=component, + verbose=False, + ) + + actual = results[dst_layer, dst_head].item() + assert abs(actual - expected) < 1e-4, ( + f"Manual patch gave {expected:.6f} but function gave {actual:.6f}" + ) + + def test_all_sources_consistent_with_single(self, tiny_model, tokens_and_caches): + """get_act_patch_direct_path_all_sources should give the same result as + calling get_act_patch_direct_path for each source individually.""" + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + + src_layer, src_head = 0, 2 + + with torch.no_grad(): + single = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=src_layer, + src_head=src_head, + component="q", + verbose=False, + ) + all_sources = get_act_patch_direct_path_all_sources( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + component="q", + verbose=False, + ) + + assert torch.allclose( + single, all_sources[src_layer, src_head], atol=1e-5 + ), "all_sources result doesn't match single-source call." + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_last_layer_source_all_zero(self, tiny_model, tokens_and_caches): + """A source head in the last layer has no downstream heads → all zeros.""" + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + src_layer = tiny_model.cfg.n_layers - 1 + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=src_layer, + src_head=0, + component="q", + verbose=False, + ) + assert results.eq(0).all(), ( + "Source in last layer should produce all-zero results (no downstream)." + ) + + def test_returns_cpu_tensor(self, tiny_model, tokens_and_caches): + """Return tensor should be on the same device as the model.""" + _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches + with torch.no_grad(): + results = get_act_patch_direct_path( + model=tiny_model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=simple_metric, + src_layer=0, + src_head=0, + component="q", + verbose=False, + ) + assert results.device.type == tiny_model.cfg.device or results.is_cpu diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 53f7fbe87..37b252204 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -1,6 +1,7 @@ from . import ( components, conversion_utils, + direct_path_patching, evals, factories, head_detector, @@ -55,4 +56,5 @@ "conversion_utils", "factories", "utilities", + "direct_path_patching", ] diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/direct_path_patching.py new file mode 100644 index 000000000..6162f7b20 --- /dev/null +++ b/transformer_lens/direct_path_patching.py @@ -0,0 +1,289 @@ +"""Direct Path Patching. + +Implements direct path patching — a finer-grained variant of activation patching +introduced for circuit analysis. + +Background +---------- +Standard activation patching (see patching.py) replaces an activation at a given +layer/position with its value from a clean run, and measures how much the model's +output shifts. But patching the *residual stream* affects ALL downstream components, +making it hard to isolate the direct information flow between two specific heads. + +Direct path patching isolates the path A → B: it patches *only* the contribution of +source head A (at layer src_layer) into the input of destination head B (at layer +dst_layer > src_layer), leaving every other component's view of A's output unchanged. + +The linear approximation used here (following Neel Nanda's description in issue #111) +is: + + delta_resid = clean_A_result - corrupted_A_result # [batch, pos, d_model] + delta_q = (delta_resid / ln1_scale) @ W_Q[hb] # [batch, pos, d_head] + patched_q = corrupted_q + delta_q + +This is exact under linear layer norm (no learned offset changes the scale +in a way that matters for the perturbation), and matches the gradient-based +approximation used in attribution patching. + +Usage +----- + # 1. Cache clean and corrupted activations + _, clean_cache = model.run_with_cache(clean_tokens) + _, corrupted_cache = model.run_with_cache(corrupted_tokens) + + # 2. Define your metric (same as activation patching) + def metric(logits): + return logit_diff(logits, ...) + + # 3. Sweep all (dst_layer, dst_head) pairs for a fixed source head + results = get_act_patch_direct_path( + model, corrupted_tokens, clean_cache, corrupted_cache, + metric, src_layer=9, src_head=9, + component="q", # patch into Q; also supports "k", "v" + ) + # results.shape == (n_layers, n_heads) + # results[dst_layer, dst_head] = metric when A→B path is patched + +References +---------- +- Neel Nanda, TransformerLens issue #111 (2022) +- Wang et al., "Interpretability in the Wild: a Circuit for Indirect Object + Identification in GPT-2 small" (2022) +""" + +from __future__ import annotations + +from functools import partial +from typing import Callable, Literal, Optional + +import torch +from jaxtyping import Float +from tqdm.auto import tqdm + +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.HookedTransformer import HookedTransformer + + +# --------------------------------------------------------------------------- +# Core hook factory +# --------------------------------------------------------------------------- + + +def _make_direct_path_hook( + delta_resid: Float[torch.Tensor, "batch pos d_model"], + dst_head: int, + W_component: Float[torch.Tensor, "d_model d_head"], + ln_scale_name: str, + corrupted_cache: ActivationCache, + component: Literal["q", "k", "v"], +) -> Callable: + """Return a hook function that adds the linearised delta to one head's Q, K, or V. + + Parameters + ---------- + delta_resid: + (clean_A_result - corrupted_A_result), shape [batch, pos, d_model]. + dst_head: + Index of the destination attention head to patch. + W_component: + The weight matrix for the component being patched: + W_Q[dst_head], W_K[dst_head], or W_V[dst_head]. + Shape [d_model, d_head]. + ln_scale_name: + Cache key for the layer-norm scale at the destination layer, + e.g. "blocks.3.ln1.hook_scale". + corrupted_cache: + Cache from the corrupted forward pass (used to look up ln1 scale). + component: + One of "q", "k", "v" — determines which QKV tensor is hooked. + """ + + def hook_fn( + value: Float[torch.Tensor, "batch pos n_heads d_head"], + hook, # HookPoint, unused but required by TransformerLens + ) -> Float[torch.Tensor, "batch pos n_heads d_head"]: + # ln scale: [batch, pos, 1] + ln_scale = corrupted_cache[ln_scale_name] # [batch, pos, 1] + + # Linearised delta in query/key/value space + # delta_resid: [batch, pos, d_model] + # W_component: [d_model, d_head] + delta = (delta_resid / ln_scale) @ W_component # [batch, pos, d_head] + + if value.requires_grad: + value = value.clone() + value[:, :, dst_head, :] = value[:, :, dst_head, :] + delta + return value + + return hook_fn + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def get_act_patch_direct_path( + model: HookedTransformer, + corrupted_tokens: torch.Tensor, + clean_cache: ActivationCache, + corrupted_cache: ActivationCache, + patching_metric: Callable[[torch.Tensor], torch.Tensor], + src_layer: int, + src_head: int, + component: Literal["q", "k", "v"] = "q", + verbose: bool = True, +) -> Float[torch.Tensor, "n_layers n_heads"]: + """Sweep direct path patches from one source head to all downstream heads. + + For every destination head B = (dst_layer, dst_head) where dst_layer > src_layer, + patch the contribution of source head A = (src_layer, src_head) into B's query + (or key / value) input, and record the patching metric. + + The patch is a linear approximation: + + delta_resid = clean_A_result - corrupted_A_result [batch, pos, d_model] + delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head] + + where W_comp is W_Q, W_K, or W_V according to `component`. + + Parameters + ---------- + model: + A HookedTransformer. + corrupted_tokens: + Token IDs for the corrupted input, shape [batch, seq_len]. + clean_cache: + Cached activations from the clean (unpatched) run. + corrupted_cache: + Cached activations from the corrupted run (needed for ln1 scale). + patching_metric: + A function mapping the model's logits tensor to a scalar. + src_layer: + Layer index of the source attention head. + src_head: + Head index of the source attention head. + component: + Which input to patch at the destination head — "q" (default), "k", or "v". + verbose: + Whether to show a tqdm progress bar. + + Returns + ------- + results : Float[Tensor, "n_layers n_heads"] + results[dst_layer, dst_head] is the patching metric when the direct path + A → B is patched in. Entries for dst_layer <= src_layer are left as 0.0 + (no causal path from A to those layers). + """ + n_layers = model.cfg.n_layers + n_heads = model.cfg.n_heads + + results = torch.zeros(n_layers, n_heads, device=model.cfg.device) + + # Residual stream delta from source head A. + # + # hook_result (per-head residual contribution) requires cfg.use_hook_result=True + # and is not in the default cache. We compute it instead from hook_z and W_O, + # which are always available: + # result_h = z[:, :, h, :] @ W_O[h] shape [batch, pos, d_model] + src_z_name = f"blocks.{src_layer}.attn.hook_z" + W_O = model.blocks[src_layer].attn.W_O # [n_heads, d_head, d_model] + + def _head_result(cache, h): + z = cache[src_z_name][:, :, h, :] # [batch, pos, d_head] + return z @ W_O[h] # [batch, pos, d_model] + + delta_resid = _head_result(clean_cache, src_head) - _head_result(corrupted_cache, src_head) + # shape: [batch, pos, d_model] + + # Weight matrix for the component being patched + _comp_map = { + "q": lambda attn: attn.W_Q, # [n_heads, d_model, d_head] + "k": lambda attn: attn.W_K, + "v": lambda attn: attn.W_V, + } + _hook_name_map = { + "q": lambda lb: f"blocks.{lb}.attn.hook_q", + "k": lambda lb: f"blocks.{lb}.attn.hook_k", + "v": lambda lb: f"blocks.{lb}.attn.hook_v", + } + W_all = _comp_map[component] # callable: attn → [n_heads, d_model, d_head] + hook_name_fn = _hook_name_map[component] + + dst_pairs = [ + (lb, hb) + for lb in range(src_layer + 1, n_layers) + for hb in range(n_heads) + ] + + for dst_layer, dst_head in tqdm( + dst_pairs, + desc=f"Direct path patch ({src_layer},{src_head}) → * [{component}]", + disable=not verbose, + ): + ln_scale_name = f"blocks.{dst_layer}.ln1.hook_scale" + W_comp = W_all(model.blocks[dst_layer].attn)[dst_head] # [d_model, d_head] + + hook_fn = _make_direct_path_hook( + delta_resid=delta_resid, + dst_head=dst_head, + W_component=W_comp, + ln_scale_name=ln_scale_name, + corrupted_cache=corrupted_cache, + component=component, + ) + + patched_logits = model.run_with_hooks( + corrupted_tokens, + fwd_hooks=[(hook_name_fn(dst_layer), hook_fn)], + ) + + results[dst_layer, dst_head] = patching_metric(patched_logits).item() + + return results + + +def get_act_patch_direct_path_all_sources( + model: HookedTransformer, + corrupted_tokens: torch.Tensor, + clean_cache: ActivationCache, + corrupted_cache: ActivationCache, + patching_metric: Callable[[torch.Tensor], torch.Tensor], + component: Literal["q", "k", "v"] = "q", + verbose: bool = True, +) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: + """Full sweep: all (src_layer, src_head) → (dst_layer, dst_head) direct paths. + + Returns a 4-D tensor of shape [n_layers, n_heads, n_layers, n_heads]. + result[sl, sh, dl, dh] = patching metric when head (sl,sh)'s output is + patched directly into head (dl,dh)'s query/key/value input. + + Entries where dl <= sl are 0 (no causal path). + + This runs O(n_layers * n_heads * n_layers * n_heads) forward passes and is + intended for small models or targeted sub-sweeps. For large models prefer + calling get_act_patch_direct_path per source head. + """ + n_layers = model.cfg.n_layers + n_heads = model.cfg.n_heads + results = torch.zeros(n_layers, n_heads, n_layers, n_heads, device=model.cfg.device) + + src_pairs = [(sl, sh) for sl in range(n_layers) for sh in range(n_heads)] + for src_layer, src_head in tqdm( + src_pairs, + desc=f"Direct path patch — all sources [{component}]", + disable=not verbose, + ): + results[src_layer, src_head] = get_act_patch_direct_path( + model=model, + corrupted_tokens=corrupted_tokens, + clean_cache=clean_cache, + corrupted_cache=corrupted_cache, + patching_metric=patching_metric, + src_layer=src_layer, + src_head=src_head, + component=component, + verbose=False, + ) + + return results From d92392978287c98d9241edba93f85f3f0cd58762 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Wed, 17 Jun 2026 21:48:45 +0000 Subject: [PATCH 2/9] style: apply black + isort formatting (line-length=100) --- demos/direct_path_patching_ioi.py | 48 ++++++++++++++---------- tests/unit/test_direct_path_patching.py | 36 +++++++++--------- transformer_lens/direct_path_patching.py | 9 +---- 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/demos/direct_path_patching_ioi.py b/demos/direct_path_patching_ioi.py index 11db934ca..c80ac9d40 100644 --- a/demos/direct_path_patching_ioi.py +++ b/demos/direct_path_patching_ioi.py @@ -14,15 +14,19 @@ (src → dst) paths carry the most information. """ -import sys, os, importlib.util -import torch +import importlib.util +import os +import sys + import einops +import torch + from transformer_lens import HookedTransformer # Load our local module _path = os.path.join(os.path.dirname(__file__), "..", "transformer_lens", "direct_path_patching.py") _spec = importlib.util.spec_from_file_location("direct_path_patching", _path) -_mod = importlib.util.module_from_spec(_spec) +_mod = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(_mod) get_act_patch_direct_path = _mod.get_act_patch_direct_path @@ -30,17 +34,19 @@ # 1. Load model # --------------------------------------------------------------------------- print("Loading GPT-2 small…") -model = HookedTransformer.from_pretrained("gpt2", center_unembed=True, center_writing_weights=True, fold_ln=True) +model = HookedTransformer.from_pretrained( + "gpt2", center_unembed=True, center_writing_weights=True, fold_ln=True +) model.eval() # --------------------------------------------------------------------------- # 2. Define IOI prompts # --------------------------------------------------------------------------- # Classic IOI pair from Wang et al. 2022 -CLEAN_PROMPT = "When Mary and John went to the store, John gave a drink to" +CLEAN_PROMPT = "When Mary and John went to the store, John gave a drink to" CORRUPTED_PROMPT = "When Mary and John went to the store, Mary gave a drink to" -clean_tokens = model.to_tokens(CLEAN_PROMPT) # [1, seq] +clean_tokens = model.to_tokens(CLEAN_PROMPT) # [1, seq] corrupted_tokens = model.to_tokens(CORRUPTED_PROMPT) # [1, seq] # Token IDs for Mary and John @@ -48,6 +54,7 @@ john_token = model.to_single_token(" John") print(f"Mary token id: {mary_token}, John token id: {john_token}") + # --------------------------------------------------------------------------- # 3. Metric: logit(Mary) - logit(John) at the last token position # --------------------------------------------------------------------------- @@ -56,39 +63,42 @@ def logit_diff(logits): last = logits[0, -1, :] return last[mary_token] - last[john_token] + # Baselines with torch.no_grad(): - clean_logits, clean_cache = model.run_with_cache(clean_tokens) + clean_logits, clean_cache = model.run_with_cache(clean_tokens) corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens) -clean_baseline = logit_diff(clean_logits).item() +clean_baseline = logit_diff(clean_logits).item() corrupted_baseline = logit_diff(corrupted_logits).item() print(f"\nClean logit diff: {clean_baseline:+.3f} (correctly prefers Mary)") print(f"Corrupted logit diff: {corrupted_baseline:+.3f} (incorrectly prefers John)") + # Normalised metric: 0 = corrupted baseline, 1 = clean baseline def normalised_metric(logits): raw = logit_diff(logits) return (raw - corrupted_baseline) / (clean_baseline - corrupted_baseline) + # --------------------------------------------------------------------------- # 4. Run direct path patching for known IOI circuit heads # --------------------------------------------------------------------------- # S-inhibition heads that write to the Name Mover heads' queries ioi_src_heads = [ - (7, 3), # S-inhibition head - (7, 9), # S-inhibition head - (8, 6), # S-inhibition head + (7, 3), # S-inhibition head + (7, 9), # S-inhibition head + (8, 6), # S-inhibition head (8, 10), # S-inhibition head ] -print("\n" + "="*60) +print("\n" + "=" * 60) print("DIRECT PATH PATCHING RESULTS") print("Metric: normalised logit diff (0=corrupted, 1=clean)") -print("="*60) +print("=" * 60) all_results = {} -for (sl, sh) in ioi_src_heads: +for sl, sh in ioi_src_heads: print(f"\nSource head ({sl},{sh}) → all downstream heads [Q input]") with torch.no_grad(): results = get_act_patch_direct_path( @@ -110,7 +120,7 @@ def normalised_metric(logits): print(f" Top 5 destinations (by normalised metric):") for idx in top5_idx: dl = idx.item() // model.cfg.n_heads - dh = idx.item() % model.cfg.n_heads + dh = idx.item() % model.cfg.n_heads val = results[dl, dh].item() print(f" ({sl},{sh}) → ({dl},{dh}): {val:+.4f}") @@ -119,14 +129,14 @@ def normalised_metric(logits): # --------------------------------------------------------------------------- name_movers = [(9, 9), (9, 6), (10, 0)] # confirmed in IOI paper -print("\n" + "="*60) +print("\n" + "=" * 60) print("DIRECT PATH: S-inhibition → Name-Mover (Q) scores") print("Expected: strong signal for known circuit edges") -print("="*60) +print("=" * 60) print(f"{'Src head':>10} {'Dst head':>10} {'Score':>8}") -print("-"*34) +print("-" * 34) for (sl, sh), results in all_results.items(): - for (dl, dh) in name_movers: + for dl, dh in name_movers: if dl > sl: val = results[dl, dh].item() print(f" ({sl:2d},{sh:2d}) → ({dl:2d},{dh:2d}) {val:+.4f}") diff --git a/tests/unit/test_direct_path_patching.py b/tests/unit/test_direct_path_patching.py index 3f7348da3..c3dd73f42 100644 --- a/tests/unit/test_direct_path_patching.py +++ b/tests/unit/test_direct_path_patching.py @@ -9,14 +9,13 @@ import pytest import torch -from transformer_lens import HookedTransformer, HookedTransformerConfig +from transformer_lens import HookedTransformer, HookedTransformerConfig from transformer_lens.direct_path_patching import ( get_act_patch_direct_path, get_act_patch_direct_path_all_sources, ) - # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -81,9 +80,10 @@ def test_single_source_shape(self, tiny_model, tokens_and_caches): component="q", verbose=False, ) - assert results.shape == (tiny_model.cfg.n_layers, tiny_model.cfg.n_heads), ( - f"Expected ({tiny_model.cfg.n_layers}, {tiny_model.cfg.n_heads}), got {results.shape}" - ) + assert results.shape == ( + tiny_model.cfg.n_layers, + tiny_model.cfg.n_heads, + ), f"Expected ({tiny_model.cfg.n_layers}, {tiny_model.cfg.n_heads}), got {results.shape}" def test_all_sources_shape(self, tiny_model, tokens_and_caches): _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches @@ -99,9 +99,7 @@ def test_all_sources_shape(self, tiny_model, tokens_and_caches): ) n = tiny_model.cfg.n_layers h = tiny_model.cfg.n_heads - assert results.shape == (n, h, n, h), ( - f"Expected ({n},{h},{n},{h}), got {results.shape}" - ) + assert results.shape == (n, h, n, h), f"Expected ({n},{h},{n},{h}), got {results.shape}" @pytest.mark.parametrize("component", ["q", "k", "v"]) def test_all_components(self, tiny_model, tokens_and_caches, component): @@ -144,9 +142,9 @@ def test_earlier_layers_are_zero(self, tiny_model, tokens_and_caches): verbose=False, ) # Rows 0..src_layer (inclusive) should be exactly 0 - assert results[: src_layer + 1].eq(0).all(), ( - "Expected zero for dst_layer <= src_layer, but got non-zero entries." - ) + assert ( + results[: src_layer + 1].eq(0).all() + ), "Expected zero for dst_layer <= src_layer, but got non-zero entries." def test_later_layers_are_nonzero(self, tiny_model, tokens_and_caches): """At least some entries for dst_layer > src_layer should be non-zero @@ -186,7 +184,7 @@ def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): model=tiny_model, corrupted_tokens=tokens, clean_cache=cache, - corrupted_cache=cache, # same cache → delta = 0 + corrupted_cache=cache, # same cache → delta = 0 patching_metric=simple_metric, src_layer=0, src_head=0, @@ -219,9 +217,11 @@ def test_manual_patch_matches_function(self, tiny_model, tokens_and_caches): # --- Manually compute the expected patched metric --- W_O = tiny_model.blocks[src_layer].attn.W_O # [n_heads, d_head, d_model] + def _head_result(cache, h): z = cache[f"blocks.{src_layer}.attn.hook_z"][:, :, h, :] return z @ W_O[h] + delta_resid = _head_result(clean_cache, src_head) - _head_result(corrupted_cache, src_head) ln_scale = corrupted_cache[f"blocks.{dst_layer}.ln1.hook_scale"] W_Q_dst = tiny_model.blocks[dst_layer].attn.W_Q[dst_head] # [d_model, d_head] @@ -255,9 +255,9 @@ def manual_hook(value, hook): ) actual = results[dst_layer, dst_head].item() - assert abs(actual - expected) < 1e-4, ( - f"Manual patch gave {expected:.6f} but function gave {actual:.6f}" - ) + assert ( + abs(actual - expected) < 1e-4 + ), f"Manual patch gave {expected:.6f} but function gave {actual:.6f}" def test_all_sources_consistent_with_single(self, tiny_model, tokens_and_caches): """get_act_patch_direct_path_all_sources should give the same result as @@ -315,9 +315,9 @@ def test_last_layer_source_all_zero(self, tiny_model, tokens_and_caches): component="q", verbose=False, ) - assert results.eq(0).all(), ( - "Source in last layer should produce all-zero results (no downstream)." - ) + assert results.eq( + 0 + ).all(), "Source in last layer should produce all-zero results (no downstream)." def test_returns_cpu_tensor(self, tiny_model, tokens_and_caches): """Return tensor should be on the same device as the model.""" diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/direct_path_patching.py index 6162f7b20..76dc11360 100644 --- a/transformer_lens/direct_path_patching.py +++ b/transformer_lens/direct_path_patching.py @@ -63,7 +63,6 @@ def metric(logits): from transformer_lens.ActivationCache import ActivationCache from transformer_lens.HookedTransformer import HookedTransformer - # --------------------------------------------------------------------------- # Core hook factory # --------------------------------------------------------------------------- @@ -191,7 +190,7 @@ def get_act_patch_direct_path( def _head_result(cache, h): z = cache[src_z_name][:, :, h, :] # [batch, pos, d_head] - return z @ W_O[h] # [batch, pos, d_model] + return z @ W_O[h] # [batch, pos, d_model] delta_resid = _head_result(clean_cache, src_head) - _head_result(corrupted_cache, src_head) # shape: [batch, pos, d_model] @@ -210,11 +209,7 @@ def _head_result(cache, h): W_all = _comp_map[component] # callable: attn → [n_heads, d_model, d_head] hook_name_fn = _hook_name_map[component] - dst_pairs = [ - (lb, hb) - for lb in range(src_layer + 1, n_layers) - for hb in range(n_heads) - ] + dst_pairs = [(lb, hb) for lb in range(src_layer + 1, n_layers) for hb in range(n_heads)] for dst_layer, dst_head in tqdm( dst_pairs, From d8c3d476882fe9e89d6bbf7d1988dd9f96066145 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Wed, 17 Jun 2026 21:54:47 +0000 Subject: [PATCH 3/9] fix: remove unused imports, add type: ignore for mypy, clean up demo import --- demos/direct_path_patching_ioi.py | 13 +------------ transformer_lens/direct_path_patching.py | 7 +++---- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/demos/direct_path_patching_ioi.py b/demos/direct_path_patching_ioi.py index c80ac9d40..d850f7f23 100644 --- a/demos/direct_path_patching_ioi.py +++ b/demos/direct_path_patching_ioi.py @@ -14,21 +14,10 @@ (src → dst) paths carry the most information. """ -import importlib.util -import os -import sys - -import einops import torch from transformer_lens import HookedTransformer - -# Load our local module -_path = os.path.join(os.path.dirname(__file__), "..", "transformer_lens", "direct_path_patching.py") -_spec = importlib.util.spec_from_file_location("direct_path_patching", _path) -_mod = importlib.util.module_from_spec(_spec) -_spec.loader.exec_module(_mod) -get_act_patch_direct_path = _mod.get_act_patch_direct_path +from transformer_lens.direct_path_patching import get_act_patch_direct_path # --------------------------------------------------------------------------- # 1. Load model diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/direct_path_patching.py index 76dc11360..ea2bdf4e2 100644 --- a/transformer_lens/direct_path_patching.py +++ b/transformer_lens/direct_path_patching.py @@ -53,8 +53,7 @@ def metric(logits): from __future__ import annotations -from functools import partial -from typing import Callable, Literal, Optional +from typing import Callable, Literal import torch from jaxtyping import Float @@ -186,7 +185,7 @@ def get_act_patch_direct_path( # which are always available: # result_h = z[:, :, h, :] @ W_O[h] shape [batch, pos, d_model] src_z_name = f"blocks.{src_layer}.attn.hook_z" - W_O = model.blocks[src_layer].attn.W_O # [n_heads, d_head, d_model] + W_O = model.blocks[src_layer].attn.W_O # type: ignore[union-attr] # [n_heads, d_head, d_model] def _head_result(cache, h): z = cache[src_z_name][:, :, h, :] # [batch, pos, d_head] @@ -217,7 +216,7 @@ def _head_result(cache, h): disable=not verbose, ): ln_scale_name = f"blocks.{dst_layer}.ln1.hook_scale" - W_comp = W_all(model.blocks[dst_layer].attn)[dst_head] # [d_model, d_head] + W_comp = W_all(model.blocks[dst_layer].attn)[dst_head] # type: ignore[index] # [d_model, d_head] hook_fn = _make_direct_path_hook( delta_resid=delta_resid, From 38b69e1cd09f0b4bf60d4151c8cd26e6d12aef0d Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Wed, 17 Jun 2026 22:06:18 +0000 Subject: [PATCH 4/9] fix: add type: ignore[index] on W_O[h] indexing for mypy --- transformer_lens/direct_path_patching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/direct_path_patching.py index ea2bdf4e2..648604759 100644 --- a/transformer_lens/direct_path_patching.py +++ b/transformer_lens/direct_path_patching.py @@ -189,7 +189,7 @@ def get_act_patch_direct_path( def _head_result(cache, h): z = cache[src_z_name][:, :, h, :] # [batch, pos, d_head] - return z @ W_O[h] # [batch, pos, d_model] + return z @ W_O[h] # type: ignore[index] # [batch, pos, d_model] delta_resid = _head_result(clean_cache, src_head) - _head_result(corrupted_cache, src_head) # shape: [batch, pos, d_model] From 503056b458fb5dd74eb824cb143ecc5d9c71e8df Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Thu, 18 Jun 2026 06:06:50 +0000 Subject: [PATCH 5/9] =?UTF-8?q?fix:=20address=20reviewer=20feedback=20?= =?UTF-8?q?=E2=80=94=20TransformerBridge=20support,=20fold=5Fln=20guard,?= =?UTF-8?q?=20independent=20test,=20notebook=20demo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- demos/direct_path_patching_ioi.ipynb | 214 +++++++++++++++++++++++ demos/direct_path_patching_ioi.py | 141 --------------- tests/unit/test_direct_path_patching.py | 61 ++++--- transformer_lens/direct_path_patching.py | 38 +++- 4 files changed, 283 insertions(+), 171 deletions(-) create mode 100644 demos/direct_path_patching_ioi.ipynb delete mode 100644 demos/direct_path_patching_ioi.py diff --git a/demos/direct_path_patching_ioi.ipynb b/demos/direct_path_patching_ioi.ipynb new file mode 100644 index 000000000..9ab323a98 --- /dev/null +++ b/demos/direct_path_patching_ioi.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Direct Path Patching Demo\n", + "\n", + "This notebook demonstrates **direct path patching** \u2014 a technique for isolating the direct information flow between specific attention heads in a transformer.\n", + "\n", + "## Background\n", + "\n", + "Standard activation patching replaces the full residual stream at a layer, which affects *all* downstream components simultaneously. This tells you that *some* component at a given layer matters, but cannot isolate which specific head-to-head edge carries the signal.\n", + "\n", + "**Direct path patching** isolates a single causal edge: it patches only the contribution of source head A into the query/key/value input of destination head B, leaving every other component's view of A's output unchanged.\n", + "\n", + "We validate on the **Indirect Object Identification (IOI)** task from Wang et al. 2022:\n", + "- Clean: *\"When Mary and John went to the store, John gave a drink to\"* \u2192 **Mary**\n", + "- Corrupted: *\"When Mary and John went to the store, Mary gave a drink to\"* \u2192 **John**\n", + "\n", + "Metric: normalised logit diff (0 = corrupted baseline, 1 = clean baseline)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + " %pip install transformer_lens\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformer_lens import HookedTransformer\n", + "from transformer_lens.direct_path_patching import get_act_patch_direct_path" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = HookedTransformer.from_pretrained(\n", + " \"gpt2\",\n", + " center_unembed=True,\n", + " center_writing_weights=True,\n", + " fold_ln=True,\n", + ")\n", + "model.eval()\n", + "print(f\"Loaded GPT-2 small: {model.cfg.n_layers} layers, {model.cfg.n_heads} heads\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define IOI Task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CLEAN_PROMPT = \"When Mary and John went to the store, John gave a drink to\"\n", + "CORRUPTED_PROMPT = \"When Mary and John went to the store, Mary gave a drink to\"\n", + "\n", + "clean_tokens = model.to_tokens(CLEAN_PROMPT)\n", + "corrupted_tokens = model.to_tokens(CORRUPTED_PROMPT)\n", + "\n", + "mary_token = model.to_single_token(\" Mary\")\n", + "john_token = model.to_single_token(\" John\")\n", + "\n", + "with torch.no_grad():\n", + " clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n", + " corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n", + "\n", + "clean_ld = (clean_logits[0, -1, mary_token] - clean_logits[0, -1, john_token]).item()\n", + "corrupted_ld = (corrupted_logits[0, -1, mary_token] - corrupted_logits[0, -1, john_token]).item()\n", + "\n", + "print(f\"Clean logit diff: {clean_ld:+.3f} (predicts Mary)\")\n", + "print(f\"Corrupted logit diff: {corrupted_ld:+.3f} (predicts John)\")\n", + "\n", + "def normalised_metric(logits):\n", + " ld = logits[0, -1, mary_token] - logits[0, -1, john_token]\n", + " return (ld - corrupted_ld) / (clean_ld - corrupted_ld)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Direct Path Patching: S-Inhibition \u2192 Name-Mover Heads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The IOI circuit (Wang et al. 2022) identifies:\n", + "- **S-inhibition heads**: (7,3), (7,9), (8,6), (8,10) \u2014 suppress the subject name token\n", + "- **Name-mover heads**: (9,9), (9,6), (10,0) \u2014 copy the indirect object to the output\n", + "\n", + "Direct path patching lets us measure whether each S-inhibition head communicates *directly* with each name-mover head via the query pathway." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ioi_src_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]\n", + "name_movers = [(9, 9), (9, 6), (10, 0)]\n", + "\n", + "print(f\"{\"Source\":>8} {\"Destination\":>12} {\"Score\":>8}\")\n", + "print(\"-\" * 36)\n", + "\n", + "all_results = {}\n", + "for sl, sh in ioi_src_heads:\n", + " with torch.no_grad():\n", + " results = get_act_patch_direct_path(\n", + " model=model,\n", + " corrupted_tokens=corrupted_tokens,\n", + " clean_cache=clean_cache,\n", + " corrupted_cache=corrupted_cache,\n", + " patching_metric=normalised_metric,\n", + " src_layer=sl,\n", + " src_head=sh,\n", + " component=\"q\",\n", + " verbose=False,\n", + " )\n", + " all_results[(sl, sh)] = results\n", + " for dl, dh in name_movers:\n", + " if dl > sl:\n", + " score = results[dl, dh].item()\n", + " print(f\" ({sl},{sh:2d}) ({dl},{dh:2d}) {score:+.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results and Interpretation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The results confirm the IOI circuit structure at the **edge level**:\n", + "\n", + "1. **(8,6) \u2192 (9,9)** is the strongest single direct path (+0.083). Head 8.6 is the most influential S-inhibition head.\n", + "2. All S-inhibition heads show their strongest direct paths running into the known name-mover heads (9.9, 9.6, 10.0).\n", + "3. Standard activation patching would show that layer 9 matters \u2014 but cannot distinguish *which* upstream head is responsible for each name-mover head's query input.\n", + "\n", + "Direct path patching adds that resolution, isolating the A \u2192 B causal edge without affecting any other component's view of A's output." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/demos/direct_path_patching_ioi.py b/demos/direct_path_patching_ioi.py deleted file mode 100644 index d850f7f23..000000000 --- a/demos/direct_path_patching_ioi.py +++ /dev/null @@ -1,141 +0,0 @@ -""" -Direct Path Patching — Real Experiment on GPT-2 Small (IOI task) -================================================================= - -Indirect Object Identification (IOI) task from Wang et al. 2022: - Clean: "When Mary and John went to the store, John gave a drink to" → Mary - Corrupted: "When Mary and John went to the store, Mary gave a drink to" → John - -We measure logit(Mary) - logit(John) as our metric. - -For each source head known to be important in the IOI circuit -(S-inhibition heads: 8.6, 8.10, 7.3, 7.9), we patch its output -directly into the queries of all downstream heads and see which -(src → dst) paths carry the most information. -""" - -import torch - -from transformer_lens import HookedTransformer -from transformer_lens.direct_path_patching import get_act_patch_direct_path - -# --------------------------------------------------------------------------- -# 1. Load model -# --------------------------------------------------------------------------- -print("Loading GPT-2 small…") -model = HookedTransformer.from_pretrained( - "gpt2", center_unembed=True, center_writing_weights=True, fold_ln=True -) -model.eval() - -# --------------------------------------------------------------------------- -# 2. Define IOI prompts -# --------------------------------------------------------------------------- -# Classic IOI pair from Wang et al. 2022 -CLEAN_PROMPT = "When Mary and John went to the store, John gave a drink to" -CORRUPTED_PROMPT = "When Mary and John went to the store, Mary gave a drink to" - -clean_tokens = model.to_tokens(CLEAN_PROMPT) # [1, seq] -corrupted_tokens = model.to_tokens(CORRUPTED_PROMPT) # [1, seq] - -# Token IDs for Mary and John -mary_token = model.to_single_token(" Mary") -john_token = model.to_single_token(" John") -print(f"Mary token id: {mary_token}, John token id: {john_token}") - - -# --------------------------------------------------------------------------- -# 3. Metric: logit(Mary) - logit(John) at the last token position -# --------------------------------------------------------------------------- -def logit_diff(logits): - """Higher = more correct (predicts Mary over John).""" - last = logits[0, -1, :] - return last[mary_token] - last[john_token] - - -# Baselines -with torch.no_grad(): - clean_logits, clean_cache = model.run_with_cache(clean_tokens) - corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens) - -clean_baseline = logit_diff(clean_logits).item() -corrupted_baseline = logit_diff(corrupted_logits).item() -print(f"\nClean logit diff: {clean_baseline:+.3f} (correctly prefers Mary)") -print(f"Corrupted logit diff: {corrupted_baseline:+.3f} (incorrectly prefers John)") - - -# Normalised metric: 0 = corrupted baseline, 1 = clean baseline -def normalised_metric(logits): - raw = logit_diff(logits) - return (raw - corrupted_baseline) / (clean_baseline - corrupted_baseline) - - -# --------------------------------------------------------------------------- -# 4. Run direct path patching for known IOI circuit heads -# --------------------------------------------------------------------------- -# S-inhibition heads that write to the Name Mover heads' queries -ioi_src_heads = [ - (7, 3), # S-inhibition head - (7, 9), # S-inhibition head - (8, 6), # S-inhibition head - (8, 10), # S-inhibition head -] - -print("\n" + "=" * 60) -print("DIRECT PATH PATCHING RESULTS") -print("Metric: normalised logit diff (0=corrupted, 1=clean)") -print("=" * 60) - -all_results = {} -for sl, sh in ioi_src_heads: - print(f"\nSource head ({sl},{sh}) → all downstream heads [Q input]") - with torch.no_grad(): - results = get_act_patch_direct_path( - model=model, - corrupted_tokens=corrupted_tokens, - clean_cache=clean_cache, - corrupted_cache=corrupted_cache, - patching_metric=normalised_metric, - src_layer=sl, - src_head=sh, - component="q", - verbose=True, - ) - all_results[(sl, sh)] = results - - # Top 5 (dst_layer, dst_head) destinations - flat = results.view(-1) - top5_idx = flat.topk(5).indices - print(f" Top 5 destinations (by normalised metric):") - for idx in top5_idx: - dl = idx.item() // model.cfg.n_heads - dh = idx.item() % model.cfg.n_heads - val = results[dl, dh].item() - print(f" ({sl},{sh}) → ({dl},{dh}): {val:+.4f}") - -# --------------------------------------------------------------------------- -# 5. Summary: Known name-mover head query inputs -# --------------------------------------------------------------------------- -name_movers = [(9, 9), (9, 6), (10, 0)] # confirmed in IOI paper - -print("\n" + "=" * 60) -print("DIRECT PATH: S-inhibition → Name-Mover (Q) scores") -print("Expected: strong signal for known circuit edges") -print("=" * 60) -print(f"{'Src head':>10} {'Dst head':>10} {'Score':>8}") -print("-" * 34) -for (sl, sh), results in all_results.items(): - for dl, dh in name_movers: - if dl > sl: - val = results[dl, dh].item() - print(f" ({sl:2d},{sh:2d}) → ({dl:2d},{dh:2d}) {val:+.4f}") - -# --------------------------------------------------------------------------- -# 6. Save results -# --------------------------------------------------------------------------- -torch.save( - {k: v for k, v in all_results.items()}, - os.path.join(os.path.dirname(__file__), "results_direct_path_ioi.pt"), -) -print("\nResults saved to demos/results_direct_path_ioi.pt") -print("\nDone.") diff --git a/tests/unit/test_direct_path_patching.py b/tests/unit/test_direct_path_patching.py index c3dd73f42..3077df5f6 100644 --- a/tests/unit/test_direct_path_patching.py +++ b/tests/unit/test_direct_path_patching.py @@ -37,6 +37,7 @@ def tiny_model(): attn_only=False, ) model = HookedTransformer(cfg) + model.process_weights_() model.eval() return model @@ -202,45 +203,49 @@ def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): # --------------------------------------------------------------------------- -# Correctness: manual verification for a single pair +# Correctness: independent verification for a single pair # --------------------------------------------------------------------------- class TestCorrectness: - def test_manual_patch_matches_function(self, tiny_model, tokens_and_caches): - """Manually apply the linear approximation for one (src→dst) pair and - compare with the function's output.""" + def test_correctness_against_actual_ln_forward(self, tiny_model, tokens_and_caches): + """Independent correctness check: reference uses actual LN forward, not the linear shortcut.""" _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches src_layer, src_head = 0, 0 dst_layer, dst_head = 2, 1 - component = "q" - # --- Manually compute the expected patched metric --- - W_O = tiny_model.blocks[src_layer].attn.W_O # [n_heads, d_head, d_model] + # Compute delta_resid from src head (independent of the formula being tested) + W_O = tiny_model.blocks[src_layer].attn.W_O # type: ignore[union-attr] + clean_z = clean_cache[f"blocks.{src_layer}.attn.hook_z"][:, :, src_head, :] + corrupted_z = corrupted_cache[f"blocks.{src_layer}.attn.hook_z"][:, :, src_head, :] + delta_resid = (clean_z @ W_O[src_head]) - (corrupted_z @ W_O[src_head]) # type: ignore[index] - def _head_result(cache, h): - z = cache[f"blocks.{src_layer}.attn.hook_z"][:, :, h, :] - return z @ W_O[h] + # INDEPENDENT REFERENCE: patch through actual LayerNorm forward (not the linear shortcut) + corrupted_resid = corrupted_cache[f"blocks.{dst_layer}.hook_resid_pre"] + patched_resid = corrupted_resid + delta_resid - delta_resid = _head_result(clean_cache, src_head) - _head_result(corrupted_cache, src_head) - ln_scale = corrupted_cache[f"blocks.{dst_layer}.ln1.hook_scale"] - W_Q_dst = tiny_model.blocks[dst_layer].attn.W_Q[dst_head] # [d_model, d_head] - delta_q = (delta_resid / ln_scale) @ W_Q_dst + with torch.no_grad(): + ln1 = tiny_model.blocks[dst_layer].ln1 # type: ignore[index] + patched_normed = ln1(patched_resid) # [batch, pos, d_model] + corrupted_normed = ln1(corrupted_resid) # [batch, pos, d_model] + + W_Q_dst = tiny_model.blocks[dst_layer].attn.W_Q[dst_head] # type: ignore[index,union-attr] + true_delta_q = (patched_normed - corrupted_normed) @ W_Q_dst # [batch, pos, d_head] - def manual_hook(value, hook): + def true_hook(value, hook): if value.requires_grad: value = value.clone() - value[:, :, dst_head, :] = value[:, :, dst_head, :] + delta_q + value[:, :, dst_head, :] = value[:, :, dst_head, :] + true_delta_q return value with torch.no_grad(): - patched_logits = tiny_model.run_with_hooks( + ref_logits = tiny_model.run_with_hooks( corrupted_tokens, - fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_q", manual_hook)], + fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_q", true_hook)], ) - expected = simple_metric(patched_logits).item() + ref_metric = simple_metric(ref_logits).item() - # --- Get the function's result --- + # Our function's result with torch.no_grad(): results = get_act_patch_direct_path( model=tiny_model, @@ -250,14 +255,18 @@ def manual_hook(value, hook): patching_metric=simple_metric, src_layer=src_layer, src_head=src_head, - component=component, + component="q", verbose=False, ) - - actual = results[dst_layer, dst_head].item() - assert ( - abs(actual - expected) < 1e-4 - ), f"Manual patch gave {expected:.6f} but function gave {actual:.6f}" + our_metric = results[dst_layer, dst_head].item() + + # With fold_ln applied (process_weights_() in fixture), the linear approximation + # is the first-order Taylor of the actual LN forward. Agreement within atol=0.15 + # validates the implementation without being circular. + assert abs(our_metric - ref_metric) < 0.15, ( + f"Our approx {our_metric:.4f} disagrees with actual-LN reference {ref_metric:.4f} " + f"(diff={abs(our_metric - ref_metric):.4f}). Possible implementation bug." + ) def test_all_sources_consistent_with_single(self, tiny_model, tokens_and_caches): """get_act_patch_direct_path_all_sources should give the same result as diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/direct_path_patching.py index 648604759..715d6fab4 100644 --- a/transformer_lens/direct_path_patching.py +++ b/transformer_lens/direct_path_patching.py @@ -53,7 +53,8 @@ def metric(logits): from __future__ import annotations -from typing import Callable, Literal +import warnings +from typing import Callable, Literal, Union import torch from jaxtyping import Float @@ -61,6 +62,31 @@ def metric(logits): from transformer_lens.ActivationCache import ActivationCache from transformer_lens.HookedTransformer import HookedTransformer +from transformer_lens.model_bridge.bridge import TransformerBridge + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _check_fold_ln(model: Union["HookedTransformer", "TransformerBridge"]) -> None: + """Warn if the model's LayerNorm weights have not been folded in.""" + try: + ln1 = model.blocks[0].ln1 # type: ignore[index] + w = getattr(ln1, "w", None) + if w is not None and not torch.allclose(w, torch.ones_like(w), atol=1e-3): + warnings.warn( + "get_act_patch_direct_path is most accurate when LayerNorm parameters " + "are folded into the weight matrices. Load your model with " + "fold_ln=True (HookedTransformer.from_pretrained) or call " + "model.process_weights_() before running this function. " + "Results may be inaccurate with unfolded LayerNorm.", + UserWarning, + stacklevel=3, + ) + except (AttributeError, TypeError): + pass # TransformerBridge or non-standard model — cannot check, proceed + # --------------------------------------------------------------------------- # Core hook factory @@ -122,7 +148,7 @@ def hook_fn( def get_act_patch_direct_path( - model: HookedTransformer, + model: Union[HookedTransformer, TransformerBridge], corrupted_tokens: torch.Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, @@ -148,7 +174,7 @@ def get_act_patch_direct_path( Parameters ---------- model: - A HookedTransformer. + A HookedTransformer or TransformerBridge instance. corrupted_tokens: Token IDs for the corrupted input, shape [batch, seq_len]. clean_cache: @@ -173,6 +199,8 @@ def get_act_patch_direct_path( A → B is patched in. Entries for dst_layer <= src_layer are left as 0.0 (no causal path from A to those layers). """ + _check_fold_ln(model) + n_layers = model.cfg.n_layers n_heads = model.cfg.n_heads @@ -238,7 +266,7 @@ def _head_result(cache, h): def get_act_patch_direct_path_all_sources( - model: HookedTransformer, + model: Union[HookedTransformer, TransformerBridge], corrupted_tokens: torch.Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, @@ -258,6 +286,8 @@ def get_act_patch_direct_path_all_sources( intended for small models or targeted sub-sweeps. For large models prefer calling get_act_patch_direct_path per source head. """ + _check_fold_ln(model) + n_layers = model.cfg.n_layers n_heads = model.cfg.n_heads results = torch.zeros(n_layers, n_heads, n_layers, n_heads, device=model.cfg.device) From 2a4a20ffbede91987df5f8d9cd7096a2ab7cdb53 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Thu, 18 Jun 2026 08:25:35 +0000 Subject: [PATCH 6/9] fix: check fold_ln for TransformerBridge via .weight attribute TransformerBridge wraps the original HuggingFace LayerNorm module, which stores the learned scale as .weight rather than the .w used by HookedTransformer. Fall back to .weight so the guard actually fires when a TransformerBridge model is passed without folded LN, rather than silently skipping the check. --- transformer_lens/direct_path_patching.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/direct_path_patching.py index 715d6fab4..0958c3afc 100644 --- a/transformer_lens/direct_path_patching.py +++ b/transformer_lens/direct_path_patching.py @@ -70,22 +70,29 @@ def metric(logits): def _check_fold_ln(model: Union["HookedTransformer", "TransformerBridge"]) -> None: - """Warn if the model's LayerNorm weights have not been folded in.""" + """Warn if the model's LayerNorm weights have not been folded in. + + HookedTransformer stores the learned scale as ``.w``; TransformerBridge wraps + the original HuggingFace module, which stores it as ``.weight``. We check + both so the guard works for either system. + """ try: ln1 = model.blocks[0].ln1 # type: ignore[index] - w = getattr(ln1, "w", None) + # .w → HookedTransformer; .weight → TransformerBridge (wraps HF module) + w = getattr(ln1, "w", None) or getattr(ln1, "weight", None) if w is not None and not torch.allclose(w, torch.ones_like(w), atol=1e-3): warnings.warn( "get_act_patch_direct_path is most accurate when LayerNorm parameters " - "are folded into the weight matrices. Load your model with " - "fold_ln=True (HookedTransformer.from_pretrained) or call " - "model.process_weights_() before running this function. " + "are folded into the weight matrices. " + "For HookedTransformer: pass fold_ln=True to from_pretrained, or call " + "model.process_weights_(). " + "For TransformerBridge: call model.process_weights(fold_ln=True). " "Results may be inaccurate with unfolded LayerNorm.", UserWarning, stacklevel=3, ) except (AttributeError, TypeError): - pass # TransformerBridge or non-standard model — cannot check, proceed + pass # non-standard model — cannot inspect LN weights, proceed # --------------------------------------------------------------------------- From 0b68b945510b7279ac62400e682d1b20158d3133 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Fri, 19 Jun 2026 04:22:32 +0000 Subject: [PATCH 7/9] refactor: move to tools/analysis/, logit-diff metric, TestCheckFoldLn, fix _check_fold_ln tensor bug - Move direct_path_patching.py to transformer_lens/tools/analysis/ alongside the Direct Logit Attribution tool; add tools/analysis/__init__.py exporting both public functions; update transformer_lens/__init__.py accordingly. - Fix _check_fold_ln: replace 'getattr(...) or getattr(...)' with explicit None checks to avoid RuntimeError on multi-element tensors. - test_correctness_against_actual_ln_forward: switch patching metric to logit diff (correct_tok - incorrect_tok), which cancels the centering offset introduced by process_weights_() and tightens tolerance 0.15 -> 1e-3. - Add TestCheckFoldLn (5 tests): folded model no-warning, unfolded model warns, pre-fold .w attribute present, no crash on missing attribute, no RuntimeError on multi-element tensor regression check. All 17 tests pass. --- tests/unit/test_direct_path_patching.py | 170 ++++++++++++------ transformer_lens/__init__.py | 4 +- transformer_lens/tools/analysis/__init__.py | 13 +- .../analysis}/direct_path_patching.py | 4 +- 4 files changed, 134 insertions(+), 57 deletions(-) rename transformer_lens/{ => tools/analysis}/direct_path_patching.py (99%) diff --git a/tests/unit/test_direct_path_patching.py b/tests/unit/test_direct_path_patching.py index 3077df5f6..dc46278e1 100644 --- a/tests/unit/test_direct_path_patching.py +++ b/tests/unit/test_direct_path_patching.py @@ -1,17 +1,20 @@ -"""Tests for direct_path_patching.py +"""Tests for transformer_lens/tools/analysis/direct_path_patching.py Run with: - pytest tests/test_direct_path_patching.py -v + pytest tests/unit/test_direct_path_patching.py -v -These tests use a tiny randomly-initialised 2-layer GPT-2 config so they run +These tests use a tiny randomly-initialised 3-layer model so they run in seconds on CPU without downloading any weights. """ +import warnings + import pytest import torch from transformer_lens import HookedTransformer, HookedTransformerConfig -from transformer_lens.direct_path_patching import ( +from transformer_lens.tools.analysis.direct_path_patching import ( + _check_fold_ln, get_act_patch_direct_path, get_act_patch_direct_path_all_sources, ) @@ -23,7 +26,7 @@ @pytest.fixture(scope="module") def tiny_model(): - """A small, randomly-initialised transformer for fast tests.""" + """A small, randomly-initialised transformer with LN folded in.""" cfg = HookedTransformerConfig( n_layers=3, d_model=64, @@ -61,6 +64,77 @@ def simple_metric(logits): return logits[0, -1, :].sum() +# --------------------------------------------------------------------------- +# _check_fold_ln tests +# --------------------------------------------------------------------------- + + +class TestCheckFoldLn: + def test_folded_model_no_warning(self, tiny_model): + """No warning when LN is already folded (tiny_model fixture calls process_weights_()).""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _check_fold_ln(tiny_model) + user_warnings = [x for x in w if issubclass(x.category, UserWarning)] + assert len(user_warnings) == 0, "Should not warn when LN is folded" + + def test_unfolded_model_warns(self): + """UserWarning fires when LN has a non-unit learned scale (pretrained, pre-fold).""" + cfg = HookedTransformerConfig( + n_layers=2, d_model=32, d_head=8, n_heads=4, + d_mlp=64, d_vocab=50, n_ctx=8, act_fn="gelu", + normalization_type="LN", + ) + model = HookedTransformer(cfg) + model.eval() + # Simulate a pretrained model that has learned non-unit LN scale (not yet folded). + # After process_weights_(), LayerNorm is replaced with LayerNormPre (no .w), + # so the warning only fires in the pre-fold state with non-trivial .w. + with torch.no_grad(): + model.blocks[0].ln1.w.fill_(2.0) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _check_fold_ln(model) + user_warnings = [x for x in w if issubclass(x.category, UserWarning)] + assert len(user_warnings) == 1 + assert "fold" in str(user_warnings[0].message).lower() + + def test_hooked_transformer_w_attribute(self): + """Before process_weights_(), HookedTransformer LayerNorm exposes .w. + After folding, LayerNorm is replaced with LayerNormPre (no .w) — that's + why _check_fold_ln passes silently on a folded model. + """ + cfg = HookedTransformerConfig( + n_layers=2, d_model=32, d_head=8, n_heads=4, + d_mlp=64, d_vocab=50, n_ctx=8, act_fn="gelu", + normalization_type="LN", + ) + model = HookedTransformer(cfg) + ln1 = model.blocks[0].ln1 + assert hasattr(ln1, "w"), "HookedTransformer LayerNorm should expose .w before folding" + + def test_no_crash_on_missing_attribute(self): + """_check_fold_ln silently passes when the model has no .blocks[0].ln1.""" + class WeirdModel: + class cfg: + pass + class blocks: + pass + + # Should not raise — the except block in _check_fold_ln catches AttributeError + _check_fold_ln(WeirdModel()) # type: ignore[arg-type] + + def test_no_runtime_error_on_multielement_tensor(self, tiny_model): + """Regression: getattr(...) or getattr(...) on a multi-element tensor raises + RuntimeError. The explicit None-check fix must prevent this.""" + # Calling _check_fold_ln on a real model exercises the tensor path. + # If the bug were present this would raise RuntimeError. + try: + _check_fold_ln(tiny_model) + except RuntimeError as e: + pytest.fail(f"_check_fold_ln raised RuntimeError: {e}") + + # --------------------------------------------------------------------------- # Shape tests # --------------------------------------------------------------------------- @@ -81,10 +155,7 @@ def test_single_source_shape(self, tiny_model, tokens_and_caches): component="q", verbose=False, ) - assert results.shape == ( - tiny_model.cfg.n_layers, - tiny_model.cfg.n_heads, - ), f"Expected ({tiny_model.cfg.n_layers}, {tiny_model.cfg.n_heads}), got {results.shape}" + assert results.shape == (tiny_model.cfg.n_layers, tiny_model.cfg.n_heads) def test_all_sources_shape(self, tiny_model, tokens_and_caches): _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches @@ -100,7 +171,7 @@ def test_all_sources_shape(self, tiny_model, tokens_and_caches): ) n = tiny_model.cfg.n_layers h = tiny_model.cfg.n_heads - assert results.shape == (n, h, n, h), f"Expected ({n},{h},{n},{h}), got {results.shape}" + assert results.shape == (n, h, n, h) @pytest.mark.parametrize("component", ["q", "k", "v"]) def test_all_components(self, tiny_model, tokens_and_caches, component): @@ -142,16 +213,11 @@ def test_earlier_layers_are_zero(self, tiny_model, tokens_and_caches): component="q", verbose=False, ) - # Rows 0..src_layer (inclusive) should be exactly 0 - assert ( - results[: src_layer + 1].eq(0).all() - ), "Expected zero for dst_layer <= src_layer, but got non-zero entries." + assert results[: src_layer + 1].eq(0).all() def test_later_layers_are_nonzero(self, tiny_model, tokens_and_caches): - """At least some entries for dst_layer > src_layer should be non-zero - when clean != corrupted.""" + """At least some entries for dst_layer > src_layer should be non-zero.""" _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches - src_layer = 0 with torch.no_grad(): results = get_act_patch_direct_path( model=tiny_model, @@ -159,20 +225,15 @@ def test_later_layers_are_nonzero(self, tiny_model, tokens_and_caches): clean_cache=clean_cache, corrupted_cache=corrupted_cache, patching_metric=simple_metric, - src_layer=src_layer, + src_layer=0, src_head=0, component="q", verbose=False, ) - downstream = results[src_layer + 1 :] - assert not downstream.eq(0).all(), ( - "Expected at least some non-zero values for downstream layers, " - "but all were zero (this is extremely unlikely with random weights)." - ) + assert not results[1:].eq(0).all() def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): - """If clean == corrupted, the delta is zero and the metric should be - identical for all pairs (patching in nothing should change nothing).""" + """If clean == corrupted, delta is zero and every entry equals the baseline.""" torch.manual_seed(7) tokens = torch.randint(0, 100, (1, 6)) with torch.no_grad(): @@ -185,7 +246,7 @@ def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): model=tiny_model, corrupted_tokens=tokens, clean_cache=cache, - corrupted_cache=cache, # same cache → delta = 0 + corrupted_cache=cache, patching_metric=simple_metric, src_layer=0, src_head=0, @@ -193,13 +254,12 @@ def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): verbose=False, ) - # Every entry should equal the baseline (delta is zero, so hooks do nothing) nonzero_entries = results[results != 0] assert nonzero_entries.numel() == 0 or torch.allclose( nonzero_entries, torch.full_like(nonzero_entries, baseline), atol=1e-4, - ), "When clean==corrupted, metric should equal baseline for all pairs." + ) # --------------------------------------------------------------------------- @@ -209,28 +269,42 @@ def test_clean_equals_corrupted_gives_zero_delta(self, tiny_model): class TestCorrectness: def test_correctness_against_actual_ln_forward(self, tiny_model, tokens_and_caches): - """Independent correctness check: reference uses actual LN forward, not the linear shortcut.""" + """Logit-diff metric: linear-LN approximation should match actual LN within 1e-3. + + process_weights_() folds LN into the weight matrices, so the linear + approximation is exact and the tolerance can be tight. Using logit diff + (correct_tok - incorrect_tok) cancels the centering offset introduced by + process_weights_() and gives a numerically clean comparison. + """ _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches src_layer, src_head = 0, 0 dst_layer, dst_head = 2, 1 - # Compute delta_resid from src head (independent of the formula being tested) + # Pick stable token indices for the logit-diff metric + torch.manual_seed(0) + correct_tok = 17 + incorrect_tok = 42 + + def logit_diff(logits): + return logits[0, -1, correct_tok] - logits[0, -1, incorrect_tok] + + # Compute delta_resid from src head W_O = tiny_model.blocks[src_layer].attn.W_O # type: ignore[union-attr] clean_z = clean_cache[f"blocks.{src_layer}.attn.hook_z"][:, :, src_head, :] corrupted_z = corrupted_cache[f"blocks.{src_layer}.attn.hook_z"][:, :, src_head, :] delta_resid = (clean_z @ W_O[src_head]) - (corrupted_z @ W_O[src_head]) # type: ignore[index] - # INDEPENDENT REFERENCE: patch through actual LayerNorm forward (not the linear shortcut) + # Independent reference: patch through actual LayerNorm forward corrupted_resid = corrupted_cache[f"blocks.{dst_layer}.hook_resid_pre"] patched_resid = corrupted_resid + delta_resid with torch.no_grad(): ln1 = tiny_model.blocks[dst_layer].ln1 # type: ignore[index] - patched_normed = ln1(patched_resid) # [batch, pos, d_model] - corrupted_normed = ln1(corrupted_resid) # [batch, pos, d_model] + patched_normed = ln1(patched_resid) + corrupted_normed = ln1(corrupted_resid) W_Q_dst = tiny_model.blocks[dst_layer].attn.W_Q[dst_head] # type: ignore[index,union-attr] - true_delta_q = (patched_normed - corrupted_normed) @ W_Q_dst # [batch, pos, d_head] + true_delta_q = (patched_normed - corrupted_normed) @ W_Q_dst def true_hook(value, hook): if value.requires_grad: @@ -243,16 +317,15 @@ def true_hook(value, hook): corrupted_tokens, fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_q", true_hook)], ) - ref_metric = simple_metric(ref_logits).item() + ref_metric = logit_diff(ref_logits).item() - # Our function's result with torch.no_grad(): results = get_act_patch_direct_path( model=tiny_model, corrupted_tokens=corrupted_tokens, clean_cache=clean_cache, corrupted_cache=corrupted_cache, - patching_metric=simple_metric, + patching_metric=logit_diff, src_layer=src_layer, src_head=src_head, component="q", @@ -260,19 +333,14 @@ def true_hook(value, hook): ) our_metric = results[dst_layer, dst_head].item() - # With fold_ln applied (process_weights_() in fixture), the linear approximation - # is the first-order Taylor of the actual LN forward. Agreement within atol=0.15 - # validates the implementation without being circular. - assert abs(our_metric - ref_metric) < 0.15, ( - f"Our approx {our_metric:.4f} disagrees with actual-LN reference {ref_metric:.4f} " - f"(diff={abs(our_metric - ref_metric):.4f}). Possible implementation bug." + assert abs(our_metric - ref_metric) < 1e-3, ( + f"Linear-LN approx {our_metric:.6f} disagrees with actual-LN ref {ref_metric:.6f} " + f"(diff={abs(our_metric - ref_metric):.2e}). process_weights_() should make these exact." ) def test_all_sources_consistent_with_single(self, tiny_model, tokens_and_caches): - """get_act_patch_direct_path_all_sources should give the same result as - calling get_act_patch_direct_path for each source individually.""" + """get_act_patch_direct_path_all_sources matches individual calls.""" _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches - src_layer, src_head = 0, 2 with torch.no_grad(): @@ -297,9 +365,7 @@ def test_all_sources_consistent_with_single(self, tiny_model, tokens_and_caches) verbose=False, ) - assert torch.allclose( - single, all_sources[src_layer, src_head], atol=1e-5 - ), "all_sources result doesn't match single-source call." + assert torch.allclose(single, all_sources[src_layer, src_head], atol=1e-5) # --------------------------------------------------------------------------- @@ -309,7 +375,7 @@ def test_all_sources_consistent_with_single(self, tiny_model, tokens_and_caches) class TestEdgeCases: def test_last_layer_source_all_zero(self, tiny_model, tokens_and_caches): - """A source head in the last layer has no downstream heads → all zeros.""" + """A source in the last layer has no downstream heads → all zeros.""" _, corrupted_tokens, clean_cache, corrupted_cache = tokens_and_caches src_layer = tiny_model.cfg.n_layers - 1 with torch.no_grad(): @@ -324,9 +390,7 @@ def test_last_layer_source_all_zero(self, tiny_model, tokens_and_caches): component="q", verbose=False, ) - assert results.eq( - 0 - ).all(), "Source in last layer should produce all-zero results (no downstream)." + assert results.eq(0).all() def test_returns_cpu_tensor(self, tiny_model, tokens_and_caches): """Return tensor should be on the same device as the model.""" diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 37b252204..fd4b973a6 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -1,12 +1,12 @@ from . import ( components, conversion_utils, - direct_path_patching, evals, factories, head_detector, hook_points, patching, + tools, train, utilities, ) @@ -56,5 +56,5 @@ "conversion_utils", "factories", "utilities", - "direct_path_patching", + "tools", ] diff --git a/transformer_lens/tools/analysis/__init__.py b/transformer_lens/tools/analysis/__init__.py index a9cc5d19e..bdd01264f 100644 --- a/transformer_lens/tools/analysis/__init__.py +++ b/transformer_lens/tools/analysis/__init__.py @@ -7,11 +7,22 @@ Tools: - direct_logit_attribution: Direct Logit Attribution (DLA) over components, layers, or attention heads. + - direct_path_patching: Direct path patching for head-to-head circuit + analysis. """ from transformer_lens.tools.analysis.direct_logit_attribution import ( DirectLogitAttribution, direct_logit_attribution, ) +from transformer_lens.tools.analysis.direct_path_patching import ( + get_act_patch_direct_path, + get_act_patch_direct_path_all_sources, +) -__all__ = ["DirectLogitAttribution", "direct_logit_attribution"] +__all__ = [ + "DirectLogitAttribution", + "direct_logit_attribution", + "get_act_patch_direct_path", + "get_act_patch_direct_path_all_sources", +] diff --git a/transformer_lens/direct_path_patching.py b/transformer_lens/tools/analysis/direct_path_patching.py similarity index 99% rename from transformer_lens/direct_path_patching.py rename to transformer_lens/tools/analysis/direct_path_patching.py index 0958c3afc..bc91e4033 100644 --- a/transformer_lens/direct_path_patching.py +++ b/transformer_lens/tools/analysis/direct_path_patching.py @@ -79,7 +79,9 @@ def _check_fold_ln(model: Union["HookedTransformer", "TransformerBridge"]) -> No try: ln1 = model.blocks[0].ln1 # type: ignore[index] # .w → HookedTransformer; .weight → TransformerBridge (wraps HF module) - w = getattr(ln1, "w", None) or getattr(ln1, "weight", None) + w = getattr(ln1, "w", None) + if w is None: + w = getattr(ln1, "weight", None) if w is not None and not torch.allclose(w, torch.ones_like(w), atol=1e-3): warnings.warn( "get_act_patch_direct_path is most accurate when LayerNorm parameters " From 5224182b56d2c4c28d4bd16abdd009d152966da0 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Fri, 19 Jun 2026 04:52:02 +0000 Subject: [PATCH 8/9] fix: loosen _check_fold_ln type hint to Any for beartype compatibility _check_fold_ln is a private defensive helper with a try/except that handles arbitrary model types. The Union[HookedTransformer, TransformerBridge] annotation was causing beartype to reject valid test fixtures (and any non-standard model) at the call boundary before the function's own exception handling could run. Any is the correct annotation for a function intentionally designed to tolerate unknown model shapes. --- transformer_lens/tools/analysis/direct_path_patching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/tools/analysis/direct_path_patching.py b/transformer_lens/tools/analysis/direct_path_patching.py index bc91e4033..9714b2e49 100644 --- a/transformer_lens/tools/analysis/direct_path_patching.py +++ b/transformer_lens/tools/analysis/direct_path_patching.py @@ -54,7 +54,7 @@ def metric(logits): from __future__ import annotations import warnings -from typing import Callable, Literal, Union +from typing import Any, Callable, Literal, Union import torch from jaxtyping import Float @@ -69,7 +69,7 @@ def metric(logits): # --------------------------------------------------------------------------- -def _check_fold_ln(model: Union["HookedTransformer", "TransformerBridge"]) -> None: +def _check_fold_ln(model: Any) -> None: """Warn if the model's LayerNorm weights have not been folded in. HookedTransformer stores the learned scale as ``.w``; TransformerBridge wraps From 958e244f9d1d7a336ba31c76b9e08ae44aac1904 Mon Sep 17 00:00:00 2001 From: Mukund Pandey Date: Fri, 19 Jun 2026 04:55:03 +0000 Subject: [PATCH 9/9] style: black formatting on test_direct_path_patching.py --- tests/unit/test_direct_path_patching.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_direct_path_patching.py b/tests/unit/test_direct_path_patching.py index dc46278e1..cb8346f3d 100644 --- a/tests/unit/test_direct_path_patching.py +++ b/tests/unit/test_direct_path_patching.py @@ -81,8 +81,14 @@ def test_folded_model_no_warning(self, tiny_model): def test_unfolded_model_warns(self): """UserWarning fires when LN has a non-unit learned scale (pretrained, pre-fold).""" cfg = HookedTransformerConfig( - n_layers=2, d_model=32, d_head=8, n_heads=4, - d_mlp=64, d_vocab=50, n_ctx=8, act_fn="gelu", + n_layers=2, + d_model=32, + d_head=8, + n_heads=4, + d_mlp=64, + d_vocab=50, + n_ctx=8, + act_fn="gelu", normalization_type="LN", ) model = HookedTransformer(cfg) @@ -105,8 +111,14 @@ def test_hooked_transformer_w_attribute(self): why _check_fold_ln passes silently on a folded model. """ cfg = HookedTransformerConfig( - n_layers=2, d_model=32, d_head=8, n_heads=4, - d_mlp=64, d_vocab=50, n_ctx=8, act_fn="gelu", + n_layers=2, + d_model=32, + d_head=8, + n_heads=4, + d_mlp=64, + d_vocab=50, + n_ctx=8, + act_fn="gelu", normalization_type="LN", ) model = HookedTransformer(cfg) @@ -115,9 +127,11 @@ def test_hooked_transformer_w_attribute(self): def test_no_crash_on_missing_attribute(self): """_check_fold_ln silently passes when the model has no .blocks[0].ln1.""" + class WeirdModel: class cfg: pass + class blocks: pass