feat: add get_act_patch_direct_path for head-to-head circuit analysis#1396
mukund1985 wants to merge 9 commits into
|
Hi @mukund1985! Thanks for putting this together. I have a collection of general notes below.
I will also provide some additional specific feedback on some of the tests, but I wanted to get you some initial feedback first.
|
Thanks for the detailed review @jlarson4! All three points addressed in the latest commit:
Happy to iterate further on any of these. |
|
@jlarson4 — quick update on everything addressed so far: General notes:
Inline comments:
One additional fix (commit Looking forward to your additional test feedback whenever you have a chance. |
|
Noticed #1398 just opened covering similar ground. Worth flagging the distinction: that PR adds a demo section to the existing Exploratory Analysis notebook but does not add any library code. This PR is the actual implementation — |
jlarson4
left a comment
Hi @mukund1985, thanks for the quick turnaround on these updates. I have a couple more small guidance comments attached, but once those are wrapped up this should be ready to merge.
I agree with your assessment that your PR and @danra's #1398 are complementary, rather than conflicting. I will reach out to him on his PR with instructions on how to ensure we get everything properly integrated.
        corrupted_tokens,
        fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_q", true_hook)],
    )
    ref_metric = simple_metric(ref_logits).item()
Use a logit difference rather than simple_metric's sum of last-token logits, e.g. lambda lg: lg[0, -1, correct] - lg[0, -1, incorrect]. It is invariant to the center_unembed step that process_weights_() applies: a difference cancels the centering offset, whereas a sum over the centered logits is pinned at the zero invariant that centering creates. In my tests this gives the correctness test real signal. It is also the standard path-patching metric, so the IOI demo should use the correct-vs-incorrect-answer version. With a stronger signal in the metric you should be able to lower atol from 0.15 to ~1e-3.
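The point about centering can be illustrated with a small, library-free sketch. It assumes that the centering step amounts to subtracting the per-position mean over the vocabulary from the logits, and that the sum metric runs over the whole vocabulary; `center`, `sum_metric`, `logit_diff` and the toy numbers are hypothetical names and values for illustration, not code from this PR.

```python
def center(logits):
    """Subtract the mean over the vocabulary (assumed effect of a centered unembedding)."""
    mean = sum(logits) / len(logits)
    return [x - mean for x in logits]


def sum_metric(logits):
    """Hypothetical stand-in for a metric that sums last-token logits."""
    return sum(logits)


def logit_diff(logits, correct, incorrect):
    """Logit of the correct token minus logit of the incorrect token."""
    return logits[correct] - logits[incorrect]


# Toy last-token logits over a 4-token vocabulary (made-up values).
raw_clean = [2.0, -1.0, 0.5, 3.5]
raw_corrupt = [0.5, 1.5, -2.0, 1.0]
clean, corrupt = center(raw_clean), center(raw_corrupt)

# After centering, the sum carries no information: it is ~0 for any input.
print(sum_metric(clean), sum_metric(corrupt))

# The difference is unchanged by centering, so it still separates the two runs.
print(logit_diff(clean, 3, 1), logit_diff(corrupt, 3, 1))
```

Because the two runs give identical sums but different logit differences, only the difference can register the effect of a patch under these assumptions.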
    try:
        ln1 = model.blocks[0].ln1  # type: ignore[index]
        # .w → HookedTransformer; .weight → TransformerBridge (wraps HF module)
        w = getattr(ln1, "w", None) or getattr(ln1, "weight", None)
Can we add a test that covers the bridge implementation as well, to ensure this fallback works as expected?
An overall repo-organization note: this branch is based on a slightly older version of dev. If you rebase to the latest dev, you'll see we've added a section precisely for files like this under transformer_lens/tools/analysis/. I would love for this file to live there alongside the Direct Logit Attribution tool.
|
@jlarson4 all three points addressed — pushed to a fresh branch rebased on latest dev:
|
|
@mukund1985 I am not seeing any new updates to this PR's code, can you double check the push for me?
|
@jlarson4 sorry about that — the push did not land last time. Should be up now (commit
All 17 tests pass.
Closes TransformerLensOrg#111.

Implements direct path patching: a finer-grained variant of activation patching that isolates the direct information flow between two specific attention heads, rather than replacing the full residual stream.

Why
---
Standard activation patching tells you that *some* component at layer L matters, but it cannot distinguish whether head B at layer L+2 matters because it received information directly from head A, or because A's output propagated through many intermediate components first. Direct path patching isolates the A → B causal edge precisely.

Implementation
--------------
For a fixed source head A = (src_layer, src_head) and every downstream destination head B = (dst_layer, dst_head):

    delta_resid = clean_A_result - corrupted_A_result   # [batch, pos, d_model]
    delta_B_q   = (delta_resid / ln1_scale) @ W_Q[hb]   # [batch, pos, d_head]
    patched_B_q = corrupted_B_q + delta_B_q

The per-head residual contribution is computed from hook_z @ W_O (always available in the default cache) rather than hook_result, which requires the non-default cfg.use_hook_result=True flag.

New files
---------
- transformer_lens/direct_path_patching.py
    get_act_patch_direct_path()              [n_layers, n_heads] sweep
    get_act_patch_direct_path_all_sources()  [n_layers, n_heads, n_layers, n_heads] full sweep
- tests/unit/test_direct_path_patching.py
    12 tests covering output shape, causal structure, manual correctness verification, and edge cases. All pass on a tiny randomly-initialised 3-layer model (no downloads, runs in ~3s on CPU).
- demos/direct_path_patching_ioi.py
    Validated on GPT-2 small / IOI task. S-inhibition heads (7.3, 7.9, 8.6, 8.10) show strongest direct paths into name-mover heads (9.9, 9.6, 10.0), confirming the Wang et al. 2022 IOI circuit.
      (8,6) → (9,9): +0.083 normalised logit diff
      (8,10) → (9,9): +0.066
      (7,9) → (9,9): +0.036

API matches existing get_act_patch_* functions in patching.py for drop-in use alongside the existing circuit analysis toolkit.
…uard, independent test, notebook demo
TransformerBridge wraps the original HuggingFace LayerNorm module, which stores the learned scale as .weight rather than the .w used by HookedTransformer. Fall back to .weight so the guard actually fires when a TransformerBridge model is passed without folded LN, rather than silently skipping the check.
…, fix _check_fold_ln tensor bug

- Move direct_path_patching.py to transformer_lens/tools/analysis/ alongside the Direct Logit Attribution tool; add tools/analysis/__init__.py exporting both public functions; update transformer_lens/__init__.py accordingly.
- Fix _check_fold_ln: replace 'getattr(...) or getattr(...)' with explicit None checks to avoid RuntimeError on multi-element tensors.
- test_correctness_against_actual_ln_forward: switch patching metric to logit diff (correct_tok - incorrect_tok), which cancels the centering offset introduced by process_weights_() and tightens tolerance 0.15 -> 1e-3.
- Add TestCheckFoldLn (5 tests): folded model no-warning, unfolded model warns, pre-fold .w attribute present, no crash on missing attribute, no RuntimeError on multi-element tensor regression check.

All 17 tests pass.
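The `getattr(...) or getattr(...)` pitfall mentioned in this commit can be reproduced without any tensor library. The sketch below is an illustration only: `FakeTensor`, `HookedStyleLN`, `BridgeStyleLN`, `get_scale_buggy` and `get_scale_fixed` are hypothetical names, and `FakeTensor` merely imitates the way a multi-element tensor refuses truth-testing. It is not the actual `_check_fold_ln` code.

```python
class FakeTensor:
    """Hypothetical stand-in for a multi-element tensor: truth-testing raises."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        raise RuntimeError("Boolean value of a multi-element tensor is ambiguous")


class HookedStyleLN:
    """LayerNorm-like object that stores its scale as .w (assumed HookedTransformer style)."""

    def __init__(self):
        self.w = FakeTensor([1.0, 1.0, 1.0])


class BridgeStyleLN:
    """LayerNorm-like object that stores its scale as .weight (assumed bridge style)."""

    def __init__(self):
        self.weight = FakeTensor([1.0, 1.0, 1.0])


def get_scale_buggy(ln):
    # `or` truth-tests the first operand, which raises when .w is a multi-element tensor.
    return getattr(ln, "w", None) or getattr(ln, "weight", None)


def get_scale_fixed(ln):
    # Explicit None checks never evaluate the tensor's truth value.
    w = getattr(ln, "w", None)
    if w is None:
        w = getattr(ln, "weight", None)
    return w


try:
    get_scale_buggy(HookedStyleLN())
except RuntimeError as err:
    print("buggy lookup failed:", err)

print(get_scale_fixed(HookedStyleLN()).values)  # found via .w
print(get_scale_fixed(BridgeStyleLN()).values)  # found via the .weight fallback
print(get_scale_fixed(object()))                # neither attribute present
```

Note that the buggy version only fails when `.w` exists, so a test that exercises the `.weight` fallback alone would not catch it; covering both attribute layouts is what exposes the difference.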
ab19354 to 0b68b94
_check_fold_ln is a private defensive helper with a try/except that handles arbitrary model types. The Union[HookedTransformer, TransformerBridge] annotation was causing beartype to reject valid test fixtures (and any non-standard model) at the call boundary before the function's own exception handling could run. Any is the correct annotation for a function intentionally designed to tolerate unknown model shapes.
|
@jlarson4 rebased onto the latest Also fixed a test failure surfaced by the rebase: |
Summary
Closes #111 — implements direct path patching, as described by @neelnanda-io in that issue.
Standard activation patching replaces the full residual stream at a layer, affecting all downstream components simultaneously. Direct path patching isolates a single causal edge: it patches only the contribution of source head A into the query (or key/value) input of destination head B, leaving every other component's view unchanged.
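As a rough illustration of the single-edge patch described above, the NumPy sketch below checks that adding a projected delta to the corrupted query equals recomputing the query with head A's clean output swapped in. It rests on several assumptions that are not taken from the PR: the LayerNorm scale is frozen at its corrupted-run value, LayerNorm weights are folded away, mean-centering and the query bias are ignored, and every array and variable name is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, pos, d_model, d_head = 1, 4, 8, 2

# Hypothetical per-head outputs of source head A in the clean and corrupted runs.
clean_A = rng.normal(size=(batch, pos, d_model))
corrupted_A = rng.normal(size=(batch, pos, d_model))
# Everything else the destination head reads from the corrupted residual stream.
other = rng.normal(size=(batch, pos, d_model))

W_Q = rng.normal(size=(d_model, d_head))                  # query weights of head B
ln1_scale = rng.uniform(1.0, 2.0, size=(batch, pos, 1))   # frozen LN scale (assumption)

# Corrupted-run query of head B.
corrupted_q = ((other + corrupted_A) / ln1_scale) @ W_Q

# Direct path patch: project only A's clean-minus-corrupted difference into B's query.
delta_resid = clean_A - corrupted_A
delta_q = (delta_resid / ln1_scale) @ W_Q
patched_q = corrupted_q + delta_q

# Reference: recompute B's query with A's clean output in place of its corrupted output.
reference_q = ((other + clean_A) / ln1_scale) @ W_Q

print(np.allclose(patched_q, reference_q))
```

The equality holds only because the scale is held fixed, which makes the map from residual stream to query linear; with a recomputed LayerNorm scale the two quantities would differ.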
Implementation
For source head A = (src_layer, src_head) and every downstream head B = (dst_layer, dst_head):

Per-head residual contribution computed from hook_z @ W_O: works with default run_with_cache(), no cfg.use_hook_result=True needed.

New files

- transformer_lens/direct_path_patching.py: get_act_patch_direct_path() and get_act_patch_direct_path_all_sources(), matching the API of existing get_act_patch_* functions
- tests/unit/test_direct_path_patching.py: 12 tests (shape, causal structure, manual correctness, edge cases)
- demos/direct_path_patching_ioi.py: validated on GPT-2 small / IOI task

Results (GPT-2 small, IOI task)
S-inhibition heads' strongest direct paths land on the known name-mover heads, confirming the Wang et al. 2022 IOI circuit.
Type of change
Checklist