feat: add get_act_patch_direct_path for head-to-head circuit analysis #1396

Open
mukund1985 wants to merge 9 commits into TransformerLensOrg:dev from mukund1985:feat/direct-path-patching

@mukund1985 commented Jun 17, 2026

Summary

Closes #111 — implements direct path patching, as described by @neelnanda-io in that issue.

Standard activation patching replaces the full residual stream at a layer, affecting all downstream components simultaneously. Direct path patching isolates a single causal edge: it patches only the contribution of source head A into the query (or key/value) input of destination head B, leaving every other component's view unchanged.

Implementation

For source head A = (src_layer, src_head) and every downstream head B = (dst_layer, dst_head):

delta_resid  = clean_A_result - corrupted_A_result    # [batch, pos, d_model]
delta_B_q    = (delta_resid / ln1_scale) @ W_Q[hb]   # [batch, pos, d_head]
patched_B_q  = corrupted_B_q + delta_B_q

Per-head residual contribution computed from hook_z @ W_O — works with default run_with_cache(), no cfg.use_hook_result=True needed.
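
The three lines above can be sketched outside the library. This is a minimal NumPy sketch under stated assumptions, not the PR's implementation: the shapes are tiny and random, `ln1_scale` is treated as a frozen per-position scale, and every function and variable name below is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, pos, n_heads, d_head, d_model = 2, 5, 4, 8, 32

# Random stand-ins for cached activations and weights.
z_clean = rng.normal(size=(batch, pos, n_heads, d_head))      # hook_z, clean run
z_corrupted = rng.normal(size=(batch, pos, n_heads, d_head))  # hook_z, corrupted run
W_O = rng.normal(size=(n_heads, d_head, d_model))             # source-layer output weights
W_Q = rng.normal(size=(n_heads, d_model, d_head))             # destination-layer query weights


def head_result(z, W_O, head):
    """Residual-stream contribution of one head: z[..., head, :] @ W_O[head]."""
    return z[:, :, head, :] @ W_O[head]


def direct_path_delta_q(z_clean, z_corrupted, W_O, W_Q, src_head, dst_head, ln1_scale):
    """Change in the destination head's query caused only by the source head."""
    delta_resid = head_result(z_clean, W_O, src_head) - head_result(z_corrupted, W_O, src_head)
    return (delta_resid / ln1_scale) @ W_Q[dst_head]


# Assumed: one frozen LayerNorm scale per (batch, position).
ln1_scale = rng.uniform(0.5, 2.0, size=(batch, pos, 1))
corrupted_q = rng.normal(size=(batch, pos, d_head))

delta_q = direct_path_delta_q(z_clean, z_corrupted, W_O, W_Q, 1, 2, ln1_scale)
patched_q = corrupted_q + delta_q
```

Because the per-head contribution is rebuilt from `z` and `W_O`, nothing in the sketch depends on a cached per-head result tensor, which mirrors the point about not needing cfg.use_hook_result=True.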

New files

  • transformer_lens/direct_path_patching.py: get_act_patch_direct_path() and get_act_patch_direct_path_all_sources(), matching the API of the existing get_act_patch_* functions
  • tests/unit/test_direct_path_patching.py: 12 tests (shape, causal structure, manual correctness, edge cases)
  • demos/direct_path_patching_ioi.py: validated on GPT-2 small / IOI task

Results (GPT-2 small, IOI task)

Source    Destination            Normalised score
(8,6)     (9,9) name mover       +0.083
(8,10)    (9,9) name mover       +0.066
(8,10)    (10,0) name mover      +0.047
(7,9)     (9,9) name mover       +0.036

S-inhibition heads' strongest direct paths land on the known name-mover heads, confirming the Wang et al. 2022 IOI circuit.
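
The PR does not spell out how the scores are normalised. One common convention in activation-patching work, assumed here purely for illustration, rescales the metric so that the corrupted baseline maps to 0 and the clean baseline maps to 1. The function name and the example inputs are hypothetical.

```python
def normalise_metric(patched, clean_baseline, corrupted_baseline):
    """Map the corrupted baseline to 0 and the clean baseline to 1."""
    span = clean_baseline - corrupted_baseline
    if span == 0:
        raise ValueError("clean and corrupted baselines are identical")
    return (patched - corrupted_baseline) / span


# Illustrative inputs only, not values from the PR.
quarter_recovered = normalise_metric(patched=0.0, clean_baseline=3.0, corrupted_baseline=-1.0)
```

Under this convention a small positive score means that patching a single edge recovers a small fraction of the clean-versus-corrupted gap.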


Type of change

  • New feature (non-breaking change which adds functionality)

Checklist

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@jlarson4

Collaborator

Hi @mukund1985! Thanks for putting this together. I have a collection of general notes below.

  1. We intend to deprecate the HookedTransformer tool in the next major release of TransformerLens (v4), which we are planning to release later this summer. Would it be possible for you to migrate this solution to our new TransformerBridge system? Based on an initial surface-level review, this should be possible. If you don't have time to do that migration yourself, just let me know and I can open an issue for someone in the community to take this part on.
  2. All of our demos are Python notebooks. Could you please refactor your demos/direct_path_patching_ioi.py into an ipynb file and include some markdown descriptions outlining the process, similar to our other notebooks? Any of the other recent notebook additions will make good examples.
  3. The method get_act_patch_direct_path requires a fold_ln=True model. Your test model does not fold LN due to the configuration, which causes inconsistencies in the test method's accuracy. Please resolve that on the test suite, and add a guard to the method to ensure that this issue does not reoccur.

I will also provide some additional specific feedback on some of the tests, but I wanted to get you some initial feedback first.

@danra mentioned this pull request Jun 18, 2026
@mukund1985

Author

Thanks for the detailed review @jlarson4! All three points addressed in the latest commit:

  1. TransformerBridge support — both get_act_patch_direct_path and get_act_patch_direct_path_all_sources now accept Union[HookedTransformer, TransformerBridge]. Added a _check_fold_ln() guard that inspects blocks[0].ln1.w and emits a UserWarning if LN weights have not been folded in; it silently passes for TransformerBridge and other model types where the check is not applicable.

  2. Test fixes:

    • Added model.process_weights_() before model.eval() in the tiny_model fixture to fold LN into the weight matrices.
    • Replaced the circular test_manual_patch_matches_function with test_correctness_against_actual_ln_forward, which builds an independent reference by running the actual ln1 forward pass on the patched residual stream (rather than replicating the same linear formula), then asserts the function's approximation agrees within atol=0.15.
  3. Notebook demo: demos/direct_path_patching_ioi.py converted to demos/direct_path_patching_ioi.ipynb with a Colab badge, markdown descriptions for each section, and the same structure as the other recent notebooks.
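
Why an "actual ln1 forward" reference can differ from the linear formula is easy to see in a small NumPy sketch. The formula reuses the corrupted run's scale, whereas a real forward pass recomputes the scale on the patched residual stream. `layer_norm_pre` is a hypothetical stand-in for a folded LayerNorm (centre and normalise, no learned parameters), not the library's class, and the delta is centred explicitly here, which is an assumption of this sketch.

```python
import numpy as np


def layer_norm_pre(x, eps=1e-5):
    """Centre and normalise the last axis; return the result and the scale used."""
    x = x - x.mean(axis=-1, keepdims=True)
    scale = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return x / scale, scale


rng = np.random.default_rng(0)
d_model, d_head = 64, 16
resid_corrupted = rng.normal(size=(1, 4, d_model))
delta_resid = 0.01 * rng.normal(size=(1, 4, d_model))  # small single-head delta
W_Q = rng.normal(size=(d_model, d_head)) / np.sqrt(d_model)

normed_corrupted, scale_corrupted = layer_norm_pre(resid_corrupted)
corrupted_q = normed_corrupted @ W_Q

# Linear approximation: reuse the corrupted run's scale for the delta.
centred_delta = delta_resid - delta_resid.mean(axis=-1, keepdims=True)
approx_q = corrupted_q + (centred_delta / scale_corrupted) @ W_Q

# Independent reference: normalise the patched residual stream from scratch.
normed_patched, _ = layer_norm_pre(resid_corrupted + delta_resid)
exact_q = normed_patched @ W_Q

max_err = np.abs(approx_q - exact_q).max()
```

The two queries agree only up to the change in scale induced by the patch, so a test of this kind needs a tolerance rather than exact equality.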

Happy to iterate further on any of these.

@mukund1985

mukund1985 commented Jun 18, 2026

Author

@jlarson4 — quick update on everything addressed so far:

General notes:

  1. TransformerBridge migration — both public functions now accept Union[HookedTransformer, TransformerBridge]. The hook API (run_with_hooks, run_with_cache, blocks[i].attn.W_Q/W_O) is identical between the two systems so no patching logic needed to change.
  2. Notebook demo: demos/direct_path_patching_ioi.ipynb, with a Colab badge and markdown descriptions for each section.
  3. fold_ln guard + test fix: _check_fold_ln() warns if LN is not folded; the tiny_model fixture calls model.process_weights_() before model.eval().

Inline comments:

  • process_weights_() called before model.eval() in the fixture (line 40).
  • Circular test replaced with test_correctness_against_actual_ln_forward — uses the actual ln1 forward pass as an independent reference rather than replicating the same linear formula.

One additional fix (commit ba3acee):
_check_fold_ln was previously silently skipping for TransformerBridge because its wrapped HuggingFace LayerNorm stores the scale as .weight rather than .w. It now falls back to .weight, so the guard fires correctly for both model systems.
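
A guard with this fallback might look roughly like the sketch below. It is not the PR's `_check_fold_ln`: the function names are hypothetical, plain lists stand in for weight tensors, and the sketch assumes that the mere presence of a scale attribute means LN has not been folded.

```python
import warnings
from types import SimpleNamespace


def find_ln_scale(ln):
    """Return the LayerNorm scale, trying `.w` first and then `.weight`."""
    scale = getattr(ln, "w", None)
    if scale is None:
        scale = getattr(ln, "weight", None)
    return scale


def check_fold_ln(model):
    """Warn when the first block's ln1 still carries a learned scale."""
    try:
        scale = find_ln_scale(model.blocks[0].ln1)
    except (AttributeError, IndexError, TypeError):
        return  # unknown model shape: skip the check rather than crash
    if scale is not None:
        warnings.warn("LayerNorm scale found; was fold_ln applied?", UserWarning)


# Stand-in models: `.w` style, `.weight` style, and a folded model with neither.
hooked_like = SimpleNamespace(blocks=[SimpleNamespace(ln1=SimpleNamespace(w=[1.3, 0.7]))])
bridge_like = SimpleNamespace(blocks=[SimpleNamespace(ln1=SimpleNamespace(weight=[1.3, 0.7]))])
folded_like = SimpleNamespace(blocks=[SimpleNamespace(ln1=SimpleNamespace())])
```

Calling `check_fold_ln` on the first two stand-ins emits a UserWarning, while the folded stand-in and objects without a `blocks` attribute pass silently.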

Looking forward to your additional test feedback whenever you have a chance.

@mukund1985

Author

Noticed #1398 just opened covering similar ground. Worth flagging the distinction: that PR adds a demo section to the existing Exploratory Analysis notebook but does not add any library code. This PR is the actual implementation — transformer_lens/direct_path_patching.py, the test suite, and the __init__ export — so the two are complementary rather than overlapping. Happy to coordinate on the demo side if it makes sense to consolidate, but the core feature lives here.

@jlarson4 left a comment

Collaborator

Hi @mukund1985, thanks for the quick turnaround on these updates. I have a couple more small guidance comments attached, but once those are wrapped up this should be ready to merge.

I agree with your assessment that your PR and @danra's #1398 are complementary, rather than conflicting. I will reach out to him on his PR with instructions on how to ensure we get everything properly integrated.

Comment thread tests/unit/test_direct_path_patching.py Outdated
    corrupted_tokens,
    fwd_hooks=[(f"blocks.{dst_layer}.attn.hook_q", true_hook)],
)
ref_metric = simple_metric(ref_logits).item()

Collaborator

Use a logit difference rather than simple_metric's summing of last token logits, e.g. lambda lg: lg[0,-1,correct] - lg[0,-1,incorrect]. It's invariant to the center_unembed that process_weights_() applies (a difference cancels the centering offset, whereas a sum hits the zero-invariant it creates), so the correctness test gets real signal based on my tests. It's also the standard path-patching metric, so the IOI demo should use the correct-vs-incorrect-answer version. With stronger signal in the metric you should be able to lower atol from 0.15 to ~1e-3.
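
The invariance argument can be checked numerically. The sketch below models the effect of centring the unembedding as subtracting the per-position mean from the logits, which is an assumption about how the centring shows up at the logit level. The token indices and function names are hypothetical.

```python
import numpy as np


def logit_diff(logits, correct, incorrect):
    """Last-position logit of the correct token minus that of the incorrect token."""
    return logits[0, -1, correct] - logits[0, -1, incorrect]


def summed_last_token(logits):
    """Sum of all last-position logits (the metric being replaced)."""
    return logits[0, -1].sum()


rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 6, 50))  # [batch, pos, vocab]

# Centring shifts every logit at a position by the same amount,
# so the logits at each position sum to zero afterwards.
centred = logits - logits.mean(axis=-1, keepdims=True)

diff_before = logit_diff(logits, correct=3, incorrect=7)
diff_after = logit_diff(centred, correct=3, incorrect=7)
sum_after = summed_last_token(centred)
```

The difference is unchanged by the shift, while the summed metric collapses to roughly zero whatever the model computes, which is why it carries no signal after centring.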

Collaborator

@danra's #1398 actually includes an example of this here, using logits_to_ave_logit_diff

try:
    ln1 = model.blocks[0].ln1  # type: ignore[index]
    # .w → HookedTransformer; .weight → TransformerBridge (wraps HF module)
    w = getattr(ln1, "w", None) or getattr(ln1, "weight", None)

Collaborator

Can we add a test that covers the bridge implementation as well, to ensure this fallback works as expected?

Collaborator

An overall repo-organization note: this branch is based on a slightly older version of dev. If you rebase to the latest dev, you'll see we've added a section precisely for files like this under transformer_lens/tools/analysis/. I would love for this file to live there alongside the Direct Logit Attribution tool.

@mukund1985

Author

@jlarson4 all three points addressed — pushed to a fresh branch rebased on latest dev:

  1. Logit diff metric: test_correctness_against_actual_ln_forward now uses logits[0,-1,correct_tok] - logits[0,-1,incorrect_tok] as the metric. This cancels the centering offset from process_weights_() and tightens the tolerance from atol=0.15 to atol=1e-3.

  2. _check_fold_ln bridge test — added TestCheckFoldLn with 5 tests covering the .w (HookedTransformer) path, the .weight (TransformerBridge) path, and the no-crash case for non-standard models. Also fixed a latent bug where getattr(...) or getattr(...) on multi-element tensors raised RuntimeError: Boolean value of Tensor with more than one value is ambiguous — replaced with explicit None checks.

  3. File location — moved direct_path_patching.py to transformer_lens/tools/analysis/ alongside the Direct Logit Attribution tool. Updated exports in tools/analysis/__init__.py. Rebased cleanly on latest dev (branch feat/direct-path-patching-v2).
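
The truthiness pitfall behind the `getattr(...) or getattr(...)` bug can be reproduced without a model. The sketch uses a NumPy array as a stand-in for a weight tensor. NumPy raises ValueError in this situation, whereas the PR reports a RuntimeError from torch, and both helper names are hypothetical.

```python
import numpy as np
from types import SimpleNamespace

ln = SimpleNamespace(w=np.array([1.0, 1.0, 1.0]))  # stand-in for a LayerNorm module


def scale_with_or(ln):
    """Buggy pattern: `or` asks for the truth value of a multi-element array."""
    return getattr(ln, "w", None) or getattr(ln, "weight", None)


def scale_with_none_check(ln):
    """Fixed pattern: compare against None explicitly."""
    scale = getattr(ln, "w", None)
    return scale if scale is not None else getattr(ln, "weight", None)


try:
    scale_with_or(ln)
    or_failed = False
except ValueError:  # NumPy's counterpart of the error torch raises
    or_failed = True

fixed_scale = scale_with_none_check(ln)
```

The explicit check also avoids a quieter failure mode of `or`: a single-element scale equal to zero would be falsy and silently skipped.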

@jlarson4

Collaborator

@mukund1985 I am not seeing any new updates to this PR's code, can you double check the push for me?

@mukund1985

Author

@jlarson4 sorry about that — the push did not land last time. Should be up now (commit ab19354). Three things in this commit:

  1. File location — moved direct_path_patching.py to transformer_lens/tools/analysis/ alongside the Direct Logit Attribution tool. Added tools/analysis/__init__.py exporting both public functions. Updated transformer_lens/__init__.py to import tools instead of the old top-level module.

  2. Logit-diff metric: test_correctness_against_actual_ln_forward now uses logits[0,-1,correct_tok] - logits[0,-1,incorrect_tok]. This cancels the centering offset from process_weights_() and tightens the tolerance from atol=0.15 to atol=1e-3.

  3. TestCheckFoldLn (5 tests) — covers the pre-fold .w path, the non-unit scale warning, no-crash on missing attribute, and a regression check for the getattr(...) or getattr(...) multi-element tensor bug (replaced with explicit None checks). Also clarified in the test comments that process_weights_() replaces LayerNorm with LayerNormPre (no .w), which is why the folded model passes silently.

All 17 tests pass.

Closes TransformerLensOrg#111.

Implements direct path patching — a finer-grained variant of activation
patching that isolates the direct information flow between two specific
attention heads, rather than replacing the full residual stream.

Why
---
Standard activation patching tells you that *some* component at layer L
matters, but it cannot distinguish whether head B at layer L+2 matters
because it received information directly from head A, or because A's
output propagated through many intermediate components first. Direct path
patching isolates the A → B causal edge precisely.

Implementation
--------------
For a fixed source head A = (src_layer, src_head) and every downstream
destination head B = (dst_layer, dst_head):

  delta_resid  = clean_A_result - corrupted_A_result    # [batch, pos, d_model]
  delta_B_q    = (delta_resid / ln1_scale) @ W_Q[hb]  # [batch, pos, d_head]
  patched_B_q  = corrupted_B_q + delta_B_q

The per-head residual contribution is computed from hook_z @ W_O (always
available in the default cache) rather than hook_result, which requires
the non-default cfg.use_hook_result=True flag.

New files
---------
- transformer_lens/direct_path_patching.py
    get_act_patch_direct_path()          [n_layers, n_heads] sweep
    get_act_patch_direct_path_all_sources()  [n_layers, n_heads, n_layers, n_heads] full sweep

- tests/unit/test_direct_path_patching.py
    12 tests covering output shape, causal structure, manual correctness
    verification, and edge cases. All pass on a tiny randomly-initialised
    3-layer model (no downloads, runs in ~3s on CPU).

- demos/direct_path_patching_ioi.py
    Validated on GPT-2 small / IOI task. S-inhibition heads (7.3, 7.9,
    8.6, 8.10) show strongest direct paths into name-mover heads (9.9,
    9.6, 10.0), confirming the Wang et al. 2022 IOI circuit.

    (8,6) → (9,9):  +0.083 normalised logit diff
    (8,10) → (9,9): +0.066
    (7,9) → (9,9):  +0.036

API matches existing get_act_patch_* functions in patching.py for
drop-in use alongside the existing circuit analysis toolkit.

TransformerBridge wraps the original HuggingFace LayerNorm module, which
stores the learned scale as .weight rather than the .w used by
HookedTransformer.  Fall back to .weight so the guard actually fires when
a TransformerBridge model is passed without folded LN, rather than silently
skipping the check.
…, fix _check_fold_ln tensor bug

- Move direct_path_patching.py to transformer_lens/tools/analysis/ alongside
  the Direct Logit Attribution tool; add tools/analysis/__init__.py exporting
  both public functions; update transformer_lens/__init__.py accordingly.

- Fix _check_fold_ln: replace 'getattr(...) or getattr(...)' with explicit
  None checks to avoid RuntimeError on multi-element tensors.

- test_correctness_against_actual_ln_forward: switch patching metric to
  logit diff (correct_tok - incorrect_tok), which cancels the centering
  offset introduced by process_weights_() and tightens tolerance 0.15 -> 1e-3.

- Add TestCheckFoldLn (5 tests): folded model no-warning, unfolded model
  warns, pre-fold .w attribute present, no crash on missing attribute,
  no RuntimeError on multi-element tensor regression check.

All 17 tests pass.
@mukund1985 mukund1985 force-pushed the feat/direct-path-patching branch from ab19354 to 0b68b94 Compare June 19, 2026 04:44
_check_fold_ln is a private defensive helper with a try/except that
handles arbitrary model types. The Union[HookedTransformer, TransformerBridge]
annotation was causing beartype to reject valid test fixtures (and any
non-standard model) at the call boundary before the function's own
exception handling could run. Any is the correct annotation for a
function intentionally designed to tolerate unknown model shapes.
@mukund1985

Author

@jlarson4 rebased onto the latest dev after #1398 merged. One conflict in tools/analysis/__init__.py — upstream had independently added DirectLogitAttribution there, so I merged both sides: the existing DLA exports are preserved alongside the direct path patching exports.

Also fixed a test failure surfaced by the rebase: test_no_crash_on_missing_attribute was being rejected by beartype at the call boundary before _check_fold_ln's own exception handling could run. Changed the type hint from Union[HookedTransformer, TransformerBridge] to Any — correct since the function is explicitly designed to tolerate arbitrary model shapes. All 17 tests pass.
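
The interaction described above, where a call-boundary type check fires before the function's own try/except can run, can be illustrated without beartype. `check_first_arg` below is a hypothetical, hand-rolled stand-in for a runtime type checker, and `Model` and `_guard_body` are placeholder names.

```python
import functools
from typing import Any


class Model:
    """Placeholder for a supported model class."""


def check_first_arg(expected_type):
    """Hypothetical stand-in for a runtime type checker applied at the call boundary."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(arg):
            if expected_type is not Any and not isinstance(arg, expected_type):
                raise TypeError(f"{fn.__name__}: expected {expected_type.__name__}")
            return fn(arg)
        return wrapper
    return decorator


def _guard_body(model):
    """Defensive helper: tolerate unknown model shapes instead of raising."""
    try:
        return model.blocks[0] is not None
    except (AttributeError, IndexError, TypeError):
        return False


strict_guard = check_first_arg(Model)(_guard_body)  # rejects before the body runs
tolerant_guard = check_first_arg(Any)(_guard_body)  # lets the body handle anything
```

With the strict annotation an unexpected object raises TypeError before the try/except is reached, while the `Any` variant hands the object to the body, which returns False.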
