Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 236 additions & 5 deletions demos/Exploratory_Analysis_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Detect if we're running in Google Colab\n",
"try:\n",
Expand All @@ -86,7 +87,9 @@
" ip = get_ipython()\n",
" if not ip.extension_manager.loaded:\n",
" ip.extension_manager.load('autoreload')\n",
" %autoreload 2\n"
" %autoreload 2\n",
"\n",
"IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\""
]
},
{
Expand Down Expand Up @@ -164,13 +167,16 @@
"metadata": {},
"outputs": [],
"source": [
"def imshow(tensor, **kwargs):\n",
" px.imshow(\n",
"def imshow(tensor, show=True, **kwargs):\n",
" fig = px.imshow(\n",
" utils.to_numpy(tensor),\n",
" color_continuous_midpoint=0.0,\n",
" color_continuous_scale=\"RdBu\",\n",
" **kwargs,\n",
" ).show()\n",
" )\n",
" if show:\n",
" fig.show()\n",
" return fig\n",
"\n",
"\n",
"def line(tensor, **kwargs):\n",
Expand Down Expand Up @@ -253,6 +259,7 @@
" fold_ln=True,\n",
" refactor_factored_attn_matrices=True,\n",
")\n",
"model.set_use_attn_result(True)\n",
"\n",
"# Get the default device used\n",
"device: torch.device = utils.get_device()"
Expand Down Expand Up @@ -1387,6 +1394,230 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Direct Path Patching\n",
"\n",
"Decomposing attention heads helped us understand whether their outputs are affected more by their computed attention pattern or values, but we've looked at each head in isolation. We can try taking this one step further by attempting to find how one attention head's output affects another's attention pattern and values. This could start revealing the circuitry of how information is moved across layers!\n",
"\n",
"One mechanism we can use is Direct Path Patching. Path patching is like activation patching, but rather than patching an activation we patch the *effect* of one activation on another, later activation. *Direct* path patching, specifically, means patching the linear part of that effect, i.e., the part that passes through the residual stream from the first activation to the later one, without going through attention heads or MLPs (see this [ARENA notebook](https://colab.research.google.com/drive/1KgrEwvCKdX-8DQ1uSiIuxwIiwzJuQ3Gw#scrollTo=b32Qdk-Gl6mU) for an alternative definition which does include MLPs).\n",
"\n",
"For example, to do direct path patching from the output of attention head `A` to the query of attention head `B` in a later layer `L`, we'd add a hook saying:\n",
"```py\n",
"patched_B_query = corrupted_B_query + (clean_A_output - corrupted_A_output) @ W_Q(B) / corrupted_ln1_scale(L)\n",
"```\n",
"\n",
"<details>\n",
"<summary>Why the corrupted ln1?</summary>\n",
"Since all earlier components in the run are corrupt, we apply the cached layer norm from the corrupted run rather than that of the clean one to get a better approximation of the linear effect (see \"Ignoring LayerNorm\" above for why dividing by the cached LayerNorm makes sense in the first place).\n",
"</details>\n",
"<br/>\n",
"\n",
"We'll look at the effects of direct path patching from outputs of attention heads to query, key and value vectors of following attention heads. As before, we patch the effects across all token positions, although it's possible to further zoom in to specific token positions (notably the last one and the second subject one).\n",
"\n",
"<details>\n",
"<summary>More on token positions</summary>\n",
"Note that direct path patching can only meaningfully be done for activations in the same token position; any path between two activations in two different token positions must go through attention, hence it is not a direct path. Still, patching a direct path from the output of an attention head in a specific token position to either the key or value vectors of a later attention head attending to the same token position could affect that head's output in other querying token positions.\n",
"</details>\n",
"<br/>\n",
"\n",
"First, the implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def patch_direct_path_A_output_to_B_attn_act(\n",
" corrupted_attn_act: Float[torch.Tensor, \"batch pos head_index d_head\"],\n",
" hook,\n",
" attn_act_type,\n",
" A_layer,\n",
" A_head_index,\n",
" B_layer,\n",
" B_head_index,\n",
" clean_cache,\n",
" corrupted_cache,\n",
"):\n",
" A_output_act_name = utils.get_act_name(\"result\", A_layer, \"attn\")\n",
" clean_A_output = clean_cache[A_output_act_name][:, :, A_head_index, :]\n",
" corrupted_A_output = corrupted_cache[A_output_act_name][:, :, A_head_index, :]\n",
" corrupted_B_act = corrupted_attn_act[:, :, B_head_index, :]\n",
" W_B = _get_attention_weights(model.blocks[B_layer].attn, attn_act_type)[B_head_index]\n",
" corrupted_ln1_scale = corrupted_cache[utils.get_act_name(\"scale\", B_layer, \"ln1\")]\n",
"\n",
" patched_B_act = corrupted_B_act + (clean_A_output - corrupted_A_output) @ W_B / corrupted_ln1_scale\n",
"\n",
" patched_attn_act = corrupted_attn_act.clone()\n",
" patched_attn_act[:, :, B_head_index, :] = patched_B_act\n",
" return patched_attn_act\n",
"\n",
"\n",
"def direct_path_output_to_attn_act_diffs(output_layer, attn_act_layer, attn_act_type):\n",
" A_to_B_diffs = torch.zeros(model.cfg.n_heads, model.cfg.n_heads, device=device, dtype=torch.float32)\n",
" for A in range(model.cfg.n_heads):\n",
" for B in range(model.cfg.n_heads):\n",
" patch_hook = (\n",
" utils.get_act_name(attn_act_type, attn_act_layer, \"attn\"),\n",
" partial(\n",
" patch_direct_path_A_output_to_B_attn_act,\n",
" attn_act_type=attn_act_type,\n",
" A_layer=output_layer,\n",
" A_head_index=A,\n",
" B_layer=attn_act_layer,\n",
" B_head_index=B,\n",
" clean_cache=cache,\n",
" corrupted_cache=corrupted_cache,\n",
" ),\n",
" )\n",
" # It would have been great to set start_layer=B and the inputs to the cached corrupted residual stream at that layer;\n",
" # unfortunately, this is not supported by TransformerBridge. Legacy HookedTransformer supports it but is slower, so\n",
" # it doesn't pay off to switch.\n",
" patched_logits = model.run_with_hooks(\n",
" corrupted_tokens,\n",
" fwd_hooks=[patch_hook],\n",
" return_type=\"logits\",\n",
" )\n",
" patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)\n",
" A_to_B_diffs[A, B] = normalize_patched_logit_diff(patched_logit_diff)\n",
" return A_to_B_diffs\n",
"\n",
"\n",
"def _get_attention_weights(attn, attn_act_type):\n",
" if attn_act_type == \"q\":\n",
" return attn.W_Q\n",
" if attn_act_type == \"k\":\n",
" return attn.W_K\n",
" if attn_act_type == \"v\":\n",
" return attn.W_V\n",
" raise ValueError"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_direct_path_output_to_attn_act_diffs(diffs, output_layer, attn_act_layer, attn_act_type, zrange=0.3):\n",
" attn_act_name = _get_attention_act_type_name(attn_act_type)\n",
" fig = imshow(\n",
" diffs,\n",
" show=False,\n",
" title=f\"Direct Path Patch Logit Difference, L{output_layer} Output to L{attn_act_layer} {attn_act_name}\",\n",
" labels={\"x\": f\"{attn_act_name} Head\", \"y\": \"Output Head\"},\n",
" zmin=-zrange,\n",
" zmax=zrange,\n",
" )\n",
" fig.update_traces(\n",
" hovertemplate=f\"L{output_layer}H%{{y}} -> L{attn_act_layer}H%{{x}}<br>%{{z:.3f}}<extra></extra>\"\n",
" )\n",
" fig.show()\n",
"\n",
"\n",
"def _get_attention_act_type_name(attn_act_type):\n",
" if attn_act_type == \"q\":\n",
" return \"Query\"\n",
" if attn_act_type == \"k\":\n",
" return \"Key\"\n",
" if attn_act_type == \"v\":\n",
" return \"Value\"\n",
" raise ValueError"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Testing that it works:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_direct_path_patching():\n",
" output_layer, value_layer = (5, 8)\n",
" diffs = direct_path_output_to_attn_act_diffs(output_layer, value_layer, \"v\")\n",
" show_direct_path_output_to_attn_act_diffs(diffs, output_layer, value_layer, \"v\")\n",
"\n",
"\n",
"test_direct_path_patching()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that L5H5's output directly affects L8H6's attention value vector. We'd previously seen that the output of either head is significant to correct prediction, but now we can further conclude that they perform (at least some of) their function together rather than in isolation!\n",
"\n",
"Moving on to queries: The previous section showed how patching the attention patterns of late-layer heads affects their outputs significantly. We could hypothesize that the computed query vectors of these heads are different in the corrupted vs. clean runs because of information that the mid-layer heads are generating. If that's the case and the effect is direct (passes through the residual stream), then direct path patching might find it! The time to iterate pairs of layers increases quadratically with the number of layers, so with this hypothesis in mind we can focus on just the mid-late layers range, and we can also limit looking up to two layers ahead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not IN_GITHUB:\n",
" for output_layer in range(7, model.cfg.n_layers - 1):\n",
" for query_layer in range(output_layer + 1, min(output_layer + 3, model.cfg.n_layers)):\n",
" diffs = direct_path_output_to_attn_act_diffs(output_layer, query_layer, \"q\")\n",
" show_direct_path_output_to_attn_act_diffs(diffs, output_layer, query_layer, \"q\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two diffs that stand out are that patching L8H6's output direct path to L9H9's query improves next token prediction, whereas patching L9H9's output direct paths to both L10H7's and L11H10's queries worsens it. Previously, we'd already seen that patching L8H6 and L9H9 outputs improves prediction and that patching L10H7 and L11H10 outputs worsens it, and that the computed attention patterns for L9H9, L10H7 and L11H10 play a significant role; we can now further conclude that L8H6's output directly affects L9H9's attention pattern significantly, and that L9H9's output directly (negatively) affects L10H7's and L11H10's attention patterns significantly, via the computed query vectors.\n",
"\n",
"What about keys? We can guess that outputs of earlier attention heads set up keys on token positions to determine what later heads attend to:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not IN_GITHUB:\n",
" for output_layer in range(5, 8):\n",
" for key_layer in range(8, model.cfg.n_layers):\n",
" diffs = direct_path_output_to_attn_act_diffs(output_layer, key_layer, \"k\")\n",
" show_direct_path_output_to_attn_act_diffs(diffs, output_layer, key_layer, \"k\", zrange=0.03)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We do find direct effects of outputs on keys, albeit at a much lower scale than on queries and values. The most significant ones are from L6H9 and L7H1 on L8H6.\n",
"\n",
"Note that it's possible stronger effects of outputs on keys exist which are not revealed, even on the layers that we did test: there might be some indirect effects, or possibly additional direct effects which are masked by some other model behavior, e.g., a later corrupted component could remove such an effect.\n",
"\n",
"Combining the results of the above experiments, we get an outline of some circuitry:\n",
"\n",
"```\n",
" 5.5 6.9,7.1\n",
" | |\n",
" v | | k\n",
" \\ /\n",
" 8.6 --> 9.9 --> 10.7,11.10\n",
" q q\n",
"```\n",
"\n",
"cf. the paper diagram below."
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -1469,7 +1700,7 @@
"Breaking down their categories:\n",
"\n",
"* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.\n",
" * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H0 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n",
" * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H5 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!\n",
" * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads. The diagram is somewhat misleading.\n",
"* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the to token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.\n",
" * In either case they attend to the second subject, so the patch that mattered was their value vectors!\n",
Expand Down
Loading