Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9deb6bf
Merge pull request #1370 from danra/patch-1
danra Jun 8, 2026
de181e2
Add Direct Logit Attribution tool for TransformerBridge (#1316)
TravisHaa Jun 8, 2026
9399fcd
Add Direct Logit Attribution tool (#1263) (#1369)
azrabano23 Jun 8, 2026
75095b1
Merge remote-tracking branch 'origin/main' into dev
jlarson4 Jun 8, 2026
a5f1193
Add stop_strings and stopping_criteria support to TransformerBridge.g…
RecreationalMath Jun 9, 2026
d37642d
Remove extra checks from Phi adapter setup_component_testing (#1373)
TensorCruncher Jun 9, 2026
35ab438
Add phi tests (#1372)
TensorCruncher Jun 9, 2026
34dc38a
Fixed SVD interpreter test (#1375)
has9800 Jun 9, 2026
f3a0ce4
Fix full-loss comparison cell in Grokking demo (#1378)
robbiebusinessacc Jun 11, 2026
d4e1800
Remove unused eps_attr config field (#1379)
RecreationalMath Jun 11, 2026
036a861
Fix typos and narrow a bare except (#1380)
RecreationalMath Jun 15, 2026
8c395ee
Add unit tests for NeoArchitectureAdapter (#1381)
chandrudp29 Jun 15, 2026
d6896df
Add unit tests for NeoxArchitectureAdapter (#1382)
chandrudp29 Jun 15, 2026
e49f78c
Add Olmo2 architecture adapter tests (#1387)
RecreationalMath Jun 15, 2026
d603a1d
Add Qwen Adapter unit tests (#1388)
has9800 Jun 15, 2026
5962cd2
Add unit tests for OpenElmArchitectureAdapter (#1383)
chandrudp29 Jun 15, 2026
b682377
Add unit tests for LlavaOnevisionArchitectureAdapter (#1384)
chandrudp29 Jun 15, 2026
3e45c5e
Use ParallelBlockBridge for StableLM parallel_attn_mlp=True branch (#…
RecreationalMath Jun 15, 2026
128723e
Add unit tests for LlamaArchitectureAdapter (#1391)
chandrudp29 Jun 15, 2026
b041d9d
Main demo notebook maintenance (#1389)
danra Jun 16, 2026
1cdd921
Add unit tests for MistralArchitectureAdapter (#1392)
chandrudp29 Jun 16, 2026
09f2eba
Add StableLM architecture adapter tests (#1393)
RecreationalMath Jun 16, 2026
56c3d91
Remove torch cap, so that newer versions of python can still resolve …
jlarson4 Jun 16, 2026
6a17449
Updating Agentic Workflows (#1395)
jlarson4 Jun 17, 2026
8691fa3
Drop round-trip and output-shape tests per the unit-test guide (#1397)
RecreationalMath Jun 18, 2026
f21898e
test: add adapter unit tests for Phi-3 and Granite (+ GraniteMoe) (#1…
mukund1985 Jun 18, 2026
cae9d46
Direct path patch demo (#1398)
danra Jun 19, 2026
5b51876
fix(benchmarks): remove hooks via HookPoint instead of bogus add_hook…
danra Jun 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Python: **>=3.10, <4.0**. CI tests 3.10, 3.11, 3.12. Format/type/docstring check
| [transformer_lens/hook_points.py](transformer_lens/hook_points.py) | `HookPoint` class and `LensHandle` |
| [transformer_lens/supported_models.py](transformer_lens/supported_models.py) | **HT-only** registry (`OFFICIAL_MODEL_NAMES`, `MODEL_ALIASES`) |
| [transformer_lens/tools/model_registry/](transformer_lens/tools/model_registry/) | Bridge-side registry + `verify_models.py` benchmark suite |
| [transformer_lens/tools/analysis/](transformer_lens/tools/analysis/) | High-level single-call analyses over the cache (e.g. `direct_logit_attribution`); works with both HT and Bridge |
| [transformer_lens/patching.py](transformer_lens/patching.py), [evals.py](transformer_lens/evals.py) | Activation patching, IOI, ROME, etc. |
| [tests/unit/](tests/unit/), [tests/integration/](tests/integration/), [tests/acceptance/](tests/acceptance/), [tests/benchmarks/](tests/benchmarks/), [tests/mps/](tests/mps/) | Test tiers |
| [demos/](demos/) | Jupyter notebooks; a subset runs in CI under `nbval` with sanitization from [demos/doc_sanitize.cfg](demos/doc_sanitize.cfg) |
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ exploratory research!

### Creator's Note (Neel Nanda)

I (Neel Nanda) used to work for the [Anthropic interpretability team](transformer-circuits.pub), and
I (Neel Nanda) used to work for the [Anthropic interpretability team](https://transformer-circuits.pub), and
I wrote this library because after I left and tried doing independent research, I got extremely
frustrated by the state of open source tooling. There's a lot of excellent infrastructure like
HuggingFace and DeepSpeed to _use_ or _train_ models, but very little to dig into their internals
Expand Down
243 changes: 236 additions & 7 deletions demos/Exploratory_Analysis_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Detect if we're running in Google Colab\n",
"try:\n",
Expand All @@ -79,16 +80,16 @@
"if IN_COLAB:\n",
" %pip install transformer_lens\n",
" %pip install circuitsvis\n",
" # Install a faster Node version\n",
" !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n",
"\n",
"# Hot reload in development mode & not running on the CD\n",
"if not IN_COLAB:\n",
" from IPython import get_ipython\n",
" ip = get_ipython()\n",
" if not ip.extension_manager.loaded:\n",
" ip.extension_manager.load('autoreload')\n",
" %autoreload 2\n"
" %autoreload 2\n",
"\n",
"IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\""
]
},
{
Expand Down Expand Up @@ -166,13 +167,16 @@
"metadata": {},
"outputs": [],
"source": [
"def imshow(tensor, **kwargs):\n",
" px.imshow(\n",
"def imshow(tensor, show=True, **kwargs):\n",
" fig = px.imshow(\n",
" utils.to_numpy(tensor),\n",
" color_continuous_midpoint=0.0,\n",
" color_continuous_scale=\"RdBu\",\n",
" **kwargs,\n",
" ).show()\n",
" )\n",
" if show:\n",
" fig.show()\n",
" return fig\n",
"\n",
"\n",
"def line(tensor, **kwargs):\n",
Expand Down Expand Up @@ -255,6 +259,7 @@
" fold_ln=True,\n",
" refactor_factored_attn_matrices=True,\n",
")\n",
"model.set_use_attn_result(True)\n",
"\n",
"# Get the default device used\n",
"device: torch.device = utils.get_device()"
Expand Down Expand Up @@ -1389,6 +1394,230 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Direct Path Patching\n",
"\n",
"Decomposing attention heads helped us understand whether their outputs are affected more by their computed attention pattern or values, but we've looked at each head in isolation. We can try taking this one step further by attempting to find how one attention head's output affects another's attention pattern and values. This could start revealing the circuitry of how information is moved across layers!\n",
"\n",
"One mechanism we can use is Direct Path Patching. Path patching is like activation patching, but rather than patching an activation we patch the *effect* of one activation on another, later activation. *Direct* path patching, specifically, means patching the linear part of that effect, i.e., the part that passes through the residual stream from the first activation to the later one, without going through attention heads or MLPs (see this [ARENA notebook](https://colab.research.google.com/drive/1KgrEwvCKdX-8DQ1uSiIuxwIiwzJuQ3Gw#scrollTo=b32Qdk-Gl6mU) for an alternative definition which does include MLPs).\n",
"\n",
"For example, to do direct path patching from the output of attention head `A` to the query of attention head `B` in a later layer `L`, we'd add a hook saying:\n",
"```py\n",
"patched_B_query = corrupted_B_query + (clean_A_output - corrupted_A_output) @ W_Q(B) / corrupted_ln1_scale(L)\n",
"```\n",
"\n",
"<details>\n",
"<summary>Why the corrupted ln1?</summary>\n",
"Since all earlier components in the run are corrupt, we apply the cached layer norm from the corrupted run rather than that of the clean one to get a better approximation of the linear effect (see \"Ignoring LayerNorm\" above for why dividing by the cached LayerNorm makes sense in the first place).\n",
"</details>\n",
"<br/>\n",
"\n",
"We'll look at the effects of direct path patching from outputs of attention heads to query, key and value vectors of following attention heads. As before, we patch the effects across all token positions, although it's possible to further zoom in to specific token positions (notably the last one and the second subject one).\n",
"\n",
"<details>\n",
"<summary>More on token positions</summary>\n",
"Note that direct path patching can only meaningfully be done for activations in the same token position; any path between two activations in two different token positions must go through attention, hence it is not a direct path. Still, patching a direct path from the output of an attention head in a specific token position to either the key or value vectors of a later attention head attending to the same token position could affect that head's output in other querying token positions.\n",
"</details>\n",
"<br/>\n",
"\n",
"First, the implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def patch_direct_path_A_output_to_B_attn_act(\n",
" corrupted_attn_act: Float[torch.Tensor, \"batch pos head_index d_head\"],\n",
" hook,\n",
" attn_act_type,\n",
" A_layer,\n",
" A_head_index,\n",
" B_layer,\n",
" B_head_index,\n",
" clean_cache,\n",
" corrupted_cache,\n",
"):\n",
" A_output_act_name = utils.get_act_name(\"result\", A_layer, \"attn\")\n",
" clean_A_output = clean_cache[A_output_act_name][:, :, A_head_index, :]\n",
" corrupted_A_output = corrupted_cache[A_output_act_name][:, :, A_head_index, :]\n",
" corrupted_B_act = corrupted_attn_act[:, :, B_head_index, :]\n",
" W_B = _get_attention_weights(model.blocks[B_layer].attn, attn_act_type)[B_head_index]\n",
" corrupted_ln1_scale = corrupted_cache[utils.get_act_name(\"scale\", B_layer, \"ln1\")]\n",
"\n",
" patched_B_act = corrupted_B_act + (clean_A_output - corrupted_A_output) @ W_B / corrupted_ln1_scale\n",
"\n",
" patched_attn_act = corrupted_attn_act.clone()\n",
" patched_attn_act[:, :, B_head_index, :] = patched_B_act\n",
" return patched_attn_act\n",
"\n",
"\n",
"def direct_path_output_to_attn_act_diffs(output_layer, attn_act_layer, attn_act_type):\n",
" A_to_B_diffs = torch.zeros(model.cfg.n_heads, model.cfg.n_heads, device=device, dtype=torch.float32)\n",
" for A in range(model.cfg.n_heads):\n",
" for B in range(model.cfg.n_heads):\n",
" patch_hook = (\n",
" utils.get_act_name(attn_act_type, attn_act_layer, \"attn\"),\n",
" partial(\n",
" patch_direct_path_A_output_to_B_attn_act,\n",
" attn_act_type=attn_act_type,\n",
" A_layer=output_layer,\n",
" A_head_index=A,\n",
" B_layer=attn_act_layer,\n",
" B_head_index=B,\n",
" clean_cache=cache,\n",
" corrupted_cache=corrupted_cache,\n",
" ),\n",
" )\n",
" # It would have been great to set start_layer=B and the inputs to the cached corrupted residual stream at that layer;\n",
" # unfortunately, this is not supported by TransformerBridge. Legacy HookedTransformer supports it but is slower, so\n",
" # it doesn't pay off to switch.\n",
" patched_logits = model.run_with_hooks(\n",
" corrupted_tokens,\n",
" fwd_hooks=[patch_hook],\n",
" return_type=\"logits\",\n",
" )\n",
" patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n",
" A_to_B_diffs[A, B] = normalize_patched_logit_diff(patched_logit_diff)\n",
" return A_to_B_diffs\n",
"\n",
"\n",
"def _get_attention_weights(attn, attn_act_type):\n",
" if attn_act_type == \"q\":\n",
" return attn.W_Q\n",
" if attn_act_type == \"k\":\n",
" return attn.W_K\n",
" if attn_act_type == \"v\":\n",
" return attn.W_V\n",
" raise ValueError"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_direct_path_output_to_attn_act_diffs(diffs, output_layer, attn_act_layer, attn_act_type, zrange=0.3):\n",
" attn_act_name = _get_attention_act_type_name(attn_act_type)\n",
" fig = imshow(\n",
" diffs,\n",
" show=False,\n",
" title=f\"Direct Path Patch Logit Difference, L{output_layer} Output to L{attn_act_layer} {attn_act_name}\",\n",
" labels={\"x\": f\"{attn_act_name} Head\", \"y\": \"Output Head\"},\n",
" zmin=-zrange,\n",
" zmax=zrange,\n",
" )\n",
" fig.update_traces(\n",
" hovertemplate=f\"L{output_layer}H%{{y}} -> L{attn_act_layer}H%{{x}}<br>%{{z:.3f}}<extra></extra>\"\n",
" )\n",
" fig.show()\n",
"\n",
"\n",
"def _get_attention_act_type_name(attn_act_type):\n",
" if attn_act_type == \"q\":\n",
" return \"Query\"\n",
" if attn_act_type == \"k\":\n",
" return \"Key\"\n",
" if attn_act_type == \"v\":\n",
" return \"Value\"\n",
" raise ValueError"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Testing that it works:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_direct_path_patching():\n",
" output_layer, value_layer = (5, 8)\n",
" diffs = direct_path_output_to_attn_act_diffs(output_layer, value_layer, \"v\")\n",
" show_direct_path_output_to_attn_act_diffs(diffs, output_layer, value_layer, \"v\")\n",
"\n",
"\n",
"test_direct_path_patching()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that L5H5's output directly affects L8H6's attention value vector. We'd previously seen that the output of either head is significant to correct prediction, but now we can further conclude that they perform (at least some of) their function together rather than in isolation!\n",
"\n",
"Moving on to queries: The previous section showed how patching the attention patterns of late-layer heads affects their outputs significantly. We could hypothesize that the computed query vectors of these heads are different in the corrupted vs. clean runs because of information that the mid-layer heads are generating. If that's the case and the effect is direct (passes through the residual stream), then direct path patching might find it! The time to iterate pairs of layers increases quadratically with the number of layers, so with this hypothesis in mind we can focus on just the mid-late layers range, and we can also limit looking up to two layers ahead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not IN_GITHUB:\n",
" for output_layer in range(7, model.cfg.n_layers - 1):\n",
" for query_layer in range(output_layer + 1, min(output_layer + 3, model.cfg.n_layers)):\n",
" diffs = direct_path_output_to_attn_act_diffs(output_layer, query_layer, \"q\")\n",
" show_direct_path_output_to_attn_act_diffs(diffs, output_layer, query_layer, \"q\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two diffs that stand out are that patching L8H6's output direct path to L9H9's query improves next token prediction, whereas patching L9H9's output direct paths to both L10H7's and L11H10's queries worsens it. Previously, we'd already seen that patching L8H6 and L9H9 outputs improves prediction and that patching L10H7 and L11H10 outputs worsens it, and that the computed attention patterns for L9H9, L10H7 and L11H10 play a significant role; we can now further conclude that L8H6's output directly affects L9H9's attention pattern significantly, and that L9H9's output directly (negatively) affects L10H7's and L11H10's attention patterns significantly, via the computed query vectors.\n",
"\n",
"What about keys? We can guess that outputs of earlier attention heads set up keys on token positions to determine what later heads attend to:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not IN_GITHUB:\n",
" for output_layer in range(5, 8):\n",
" for key_layer in range(8, model.cfg.n_layers):\n",
" diffs = direct_path_output_to_attn_act_diffs(output_layer, key_layer, \"k\")\n",
" show_direct_path_output_to_attn_act_diffs(diffs, output_layer, key_layer, \"k\", zrange=0.03)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We do find direct effects of outputs on keys, albeit at a much lower scale than on queries and values. The most significant ones are from L6H9 and L7H1 on L8H6.\n",
"\n",
"Note that it's possible stronger effects of outputs on keys exist which are not revealed, even on the layers that we did test: there might be some indirect effects, or possibly additional direct effects which are masked by some other model behavior, e.g., a later corrupted component could remove such an effect.\n",
"\n",
"Combining the results of the above experiments, we get an outline of some circuitry:\n",
"\n",
"```\n",
" 5.5 6.9,7.1\n",
" | |\n",
" v | | k\n",
" \\ /\n",
" 8.6 --> 9.9 --> 10.7,11.10\n",
" q q\n",
"```\n",
"\n",
"cf. the paper diagram below."
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -1471,7 +1700,7 @@
"Breaking down their categories:\n",
"\n",
"* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.\n",
" * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H0 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n",
" * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H5 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n",
" * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads. The diagram is somewhat misleading.\n",
"* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the to token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.\n",
" * In either case they attend to the second subject, so the patch that mattered was their value vectors!\n",
Expand Down
17 changes: 2 additions & 15 deletions demos/Grokking_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3492,22 +3492,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_1215793/3004607503.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/tmp/ipykernel_1215793/4096650173.py\u001b[0m in \u001b[0;36mloss_fn\u001b[0;34m(logits, labels)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlog_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mcorrect_log_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcorrect_log_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtrain_logits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1"
]
}
],
"outputs": [],
"source": [
"print(loss_fn(all_logits, labels)) # This bugged on models not fully trained "
"print(loss_fn(original_logits, labels))"
]
},
{
Expand Down
Loading
Loading