Merged
214 changes: 214 additions & 0 deletions demos/direct_path_patching_ioi.ipynb
@@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/direct_path_patching_ioi.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Direct Path Patching Demo\n",
"\n",
"This notebook demonstrates **direct path patching** \u2014 a technique for isolating the direct information flow between specific attention heads in a transformer.\n",
"\n",
"## Background\n",
"\n",
"Standard activation patching replaces a whole activation, such as the residual stream at a layer or a single head's output, so *all* downstream components see the patched value at once. This tells you that *some* component at a given layer matters, but cannot isolate which specific head-to-head edge carries the signal.\n",
"\n",
"**Direct path patching** isolates a single causal edge: it patches only the contribution of source head A into the query/key/value input of destination head B, leaving every other component's view of A's output unchanged.\n",
"\n",
"We validate on the **Indirect Object Identification (IOI)** task from Wang et al. 2022:\n",
"- Clean: *\"When Mary and John went to the store, John gave a drink to\"* \u2192 **Mary**\n",
"- Corrupted: *\"When Mary and John went to the store, Mary gave a drink to\"* \u2192 **John**\n",
"\n",
"Metric: normalised logit diff, where the logit diff is logit(Mary) minus logit(John) at the final token position (0 = corrupted baseline, 1 = clean baseline)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"import os\n",
"try:\n",
"    import google.colab\n",
"    IN_COLAB = True\n",
"    print(\"Running as a Colab notebook\")\n",
"    %pip install transformer_lens\n",
"except ImportError:\n",
"    IN_COLAB = False\n",
"    print(\"Running as a Jupyter notebook\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformer_lens import HookedTransformer\n",
"from transformer_lens.direct_path_patching import get_act_patch_direct_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = HookedTransformer.from_pretrained(\n",
"    \"gpt2\",\n",
"    center_unembed=True,\n",
"    center_writing_weights=True,\n",
"    fold_ln=True,\n",
")\n",
"model.eval()\n",
"print(f\"Loaded GPT-2 small: {model.cfg.n_layers} layers, {model.cfg.n_heads} heads\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define IOI Task"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CLEAN_PROMPT = \"When Mary and John went to the store, John gave a drink to\"\n",
"CORRUPTED_PROMPT = \"When Mary and John went to the store, Mary gave a drink to\"\n",
"\n",
"clean_tokens = model.to_tokens(CLEAN_PROMPT)\n",
"corrupted_tokens = model.to_tokens(CORRUPTED_PROMPT)\n",
"\n",
"mary_token = model.to_single_token(\" Mary\")\n",
"john_token = model.to_single_token(\" John\")\n",
"\n",
"with torch.no_grad():\n",
"    clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n",
"    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n",
"\n",
"# Logit diff = logit(\" Mary\") - logit(\" John\") at the final token position\n",
"clean_ld = (clean_logits[0, -1, mary_token] - clean_logits[0, -1, john_token]).item()\n",
"corrupted_ld = (corrupted_logits[0, -1, mary_token] - corrupted_logits[0, -1, john_token]).item()\n",
"\n",
"print(f\"Clean logit diff: {clean_ld:+.3f} (predicts Mary)\")\n",
"print(f\"Corrupted logit diff: {corrupted_ld:+.3f} (predicts John)\")\n",
"\n",
"def normalised_metric(logits):\n",
"    # Rescale so that 0 = corrupted baseline and 1 = clean baseline\n",
"    ld = logits[0, -1, mary_token] - logits[0, -1, john_token]\n",
"    return (ld - corrupted_ld) / (clean_ld - corrupted_ld)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Direct Path Patching: S-Inhibition \u2192 Name-Mover Heads"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The IOI circuit (Wang et al. 2022) identifies:\n",
"- **S-inhibition heads**: (7,3), (7,9), (8,6), (8,10), which steer the name-mover heads' attention away from the subject name\n",
"- **Name-mover heads**: (9,9), (9,6), (10,0), which copy the indirect object to the output\n",
"\n",
"Direct path patching lets us measure whether each S-inhibition head communicates *directly* with each name-mover head via the query pathway."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ioi_src_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]\n",
"name_movers = [(9, 9), (9, 6), (10, 0)]\n",
"\n",
"# Single quotes inside the f-string keep this valid on Python < 3.12\n",
"print(f\"{'Source':>8}  {'Destination':>12}  {'Score':>8}\")\n",
"print(\"-\" * 32)\n",
"\n",
"all_results = {}\n",
"for sl, sh in ioi_src_heads:\n",
"    with torch.no_grad():\n",
"        results = get_act_patch_direct_path(\n",
"            model=model,\n",
"            corrupted_tokens=corrupted_tokens,\n",
"            clean_cache=clean_cache,\n",
"            corrupted_cache=corrupted_cache,\n",
"            patching_metric=normalised_metric,\n",
"            src_layer=sl,\n",
"            src_head=sh,\n",
"            component=\"q\",\n",
"            verbose=False,\n",
"        )\n",
"    all_results[(sl, sh)] = results\n",
"    for dl, dh in name_movers:\n",
"        # A head can only read from heads in earlier layers\n",
"        if dl > sl:\n",
"            score = results[dl, dh].item()\n",
"            src, dst = f\"({sl},{sh})\", f\"({dl},{dh})\"\n",
"            print(f\"{src:>8}  {dst:>12}  {score:>+8.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Results and Interpretation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results are consistent with the IOI circuit structure at the **edge level**:\n",
"\n",
"1. **(8,6) \u2192 (9,9)** is the strongest single direct path (+0.083), which makes (8,6) the S-inhibition head with the largest direct effect on a name-mover head's query.\n",
"2. All S-inhibition heads show their strongest direct paths running into the known name-mover heads (9,9), (9,6) and (10,0). The table above prints only these destinations; the scores for every destination head are kept in `all_results`.\n",
"3. Standard activation patching would show that layer 9 matters, but cannot distinguish *which* upstream head is responsible for each name-mover head's query input.\n",
"\n",
"Direct path patching adds that resolution, isolating the A \u2192 B causal edge without affecting any other component's view of A's output."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
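
The difference between full activation patching and direct path patching, as described in the notebook's Background cell, can be sketched on a toy scalar model. This is only an illustration under stated assumptions: it is not the TransformerLens implementation and does not call `get_act_patch_direct_path`. The weights `W_A` and `W_B`, the `forward` function, and its `a_for_b` / `a_for_output` arguments are hypothetical names made up for this example. Head A writes into the residual stream, head B reads the stream, and the output reads everything.

```python
# Toy scalar "transformer" (hypothetical; not TransformerLens).
#   head A writes   a = W_A * x
#   head B reads    x + a   and writes   b = W_B * (x + a)
#   the output reads the whole residual stream:   y = x + a + b
W_A = 2.0  # assumed toy weight for head A
W_B = 3.0  # assumed toy weight for head B


def head_a(x):
    """Output of source head A on input x."""
    return W_A * x


def forward(x, a_for_b=None, a_for_output=None):
    """Run the toy model, optionally overriding what B or the output sees of A."""
    a = head_a(x)
    a_seen_by_b = a if a_for_b is None else a_for_b
    a_seen_by_output = a if a_for_output is None else a_for_output
    b = W_B * (x + a_seen_by_b)
    return x + a_seen_by_output + b


x_clean, x_corrupted = 1.0, -1.0
a_clean = head_a(x_clean)

y_clean = forward(x_clean)
y_corrupted = forward(x_corrupted)


def normalised(y):
    """0 = corrupted baseline, 1 = clean baseline (same idea as the notebook metric)."""
    return (y - y_corrupted) / (y_clean - y_corrupted)


# Full activation patching of A: every downstream reader sees the clean A output.
y_full = forward(x_corrupted, a_for_b=a_clean, a_for_output=a_clean)

# Direct path patching A -> B: only B's input sees the clean A output;
# the output's own view of A stays corrupted.
y_direct = forward(x_corrupted, a_for_b=a_clean)

print(f"full activation patch: {normalised(y_full):.3f}")
print(f"direct path A -> B:    {normalised(y_direct):.3f}")
```

In this toy setup the gap between the two scores is A's direct path to the output, which full activation patching includes and direct path patching leaves out.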