Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions transformer_lens/benchmarks/backward_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
make_grad_capture_hook,
safe_allclose,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge import TransformerBridge


Expand Down Expand Up @@ -44,22 +45,24 @@ def benchmark_backward_hooks(
hook_names = list(bridge._hook_registry.keys())

# Register backward hooks on bridge
bridge_handles = []
bridge_hook_points: list[HookPoint] = []
for hook_name in hook_names:
if hook_name in bridge.hook_dict:
hook_point = bridge.hook_dict[hook_name]
handle = hook_point.add_hook(make_grad_capture_hook(bridge_gradients, hook_name, return_none=True), dir="bwd") # type: ignore[func-returns-value]
bridge_handles.append(handle)
hook_point.add_hook(
make_grad_capture_hook(bridge_gradients, hook_name, return_none=True),
dir="bwd",
)
bridge_hook_points.append(hook_point)

# Run bridge forward and backward
bridge_output = bridge(test_text)
bridge_loss = bridge_output[:, -1, :].sum()
bridge_loss.backward()

# Clean up hooks
for handle in bridge_handles:
if handle is not None:
handle.remove()
for hook_point in bridge_hook_points:
hook_point.remove_hooks(dir="bwd")

if reference_model is None:
# No reference - just verify gradients were captured
Expand All @@ -77,22 +80,24 @@ def benchmark_backward_hooks(
return result

# Register backward hooks on reference model
reference_handles = []
reference_hook_points: list[HookPoint] = []
for hook_name in hook_names:
if hook_name in reference_model.hook_dict:
hook_point = reference_model.hook_dict[hook_name]
handle = hook_point.add_hook(make_grad_capture_hook(reference_gradients, hook_name, return_none=True), dir="bwd") # type: ignore[func-returns-value]
reference_handles.append(handle)
hook_point.add_hook(
make_grad_capture_hook(reference_gradients, hook_name, return_none=True),
dir="bwd",
)
reference_hook_points.append(hook_point)

# Run reference forward and backward
reference_output = reference_model(test_text)
reference_loss = reference_output[:, -1, :].sum()
reference_loss.backward()

# Clean up hooks
for handle in reference_handles:
if handle is not None:
handle.remove()
for hook_point in reference_hook_points:
hook_point.remove_hooks(dir="bwd")

# Compare gradients
common_hooks = set(bridge_gradients.keys()) & set(reference_gradients.keys())
Expand Down Expand Up @@ -295,22 +300,24 @@ def benchmark_critical_backward_hooks(
bridge_gradients: Dict[str, torch.Tensor] = {}

# Register backward hooks on bridge
bridge_handles = []
bridge_hook_points: list[HookPoint] = []
for hook_name in critical_hooks:
if hook_name in bridge.hook_dict:
hook_point = bridge.hook_dict[hook_name]
handle = hook_point.add_hook(make_grad_capture_hook(bridge_gradients, hook_name, return_none=True), dir="bwd") # type: ignore[func-returns-value]
bridge_handles.append(handle)
hook_point.add_hook(
make_grad_capture_hook(bridge_gradients, hook_name, return_none=True),
dir="bwd",
)
bridge_hook_points.append(hook_point)

# Run bridge forward and backward
bridge_output = bridge(test_text)
bridge_loss = bridge_output[:, -1, :].sum()
bridge_loss.backward()

# Clean up hooks
for handle in bridge_handles:
if handle is not None:
handle.remove()
for hook_point in bridge_hook_points:
hook_point.remove_hooks(dir="bwd")

if reference_model is None:
# No reference - just verify gradients were captured
Expand All @@ -331,22 +338,24 @@ def benchmark_critical_backward_hooks(
# Register backward hooks on reference model
reference_gradients: Dict[str, torch.Tensor] = {}

reference_handles = []
reference_hook_points: list[HookPoint] = []
for hook_name in critical_hooks:
if hook_name in reference_model.hook_dict:
hook_point = reference_model.hook_dict[hook_name]
handle = hook_point.add_hook(make_grad_capture_hook(reference_gradients, hook_name, return_none=True), dir="bwd") # type: ignore[func-returns-value]
reference_handles.append(handle)
hook_point.add_hook(
make_grad_capture_hook(reference_gradients, hook_name, return_none=True),
dir="bwd",
)
reference_hook_points.append(hook_point)

# Run reference forward and backward
reference_output = reference_model(test_text)
reference_loss = reference_output[:, -1, :].sum()
reference_loss.backward()

# Clean up hooks
for handle in reference_handles:
if handle is not None:
handle.remove()
for hook_point in reference_hook_points:
hook_point.remove_hooks(dir="bwd")

# Compare gradients
mismatches = []
Expand Down
58 changes: 27 additions & 31 deletions transformer_lens/benchmarks/hook_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
filter_expected_missing_hooks,
make_capture_hook,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge import TransformerBridge


Expand Down Expand Up @@ -134,13 +135,13 @@ def benchmark_forward_hooks(
hook_names = list(bridge.hook_dict.keys())

# Register hooks on bridge and track missing hooks
bridge_handles = []
bridge_hook_points: list[tuple[str, HookPoint]] = []
missing_from_bridge = []
for hook_name in hook_names:
if hook_name in bridge.hook_dict:
hook_point = bridge.hook_dict[hook_name]
handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name)) # type: ignore[func-returns-value]
bridge_handles.append((hook_name, handle))
hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
bridge_hook_points.append((hook_name, hook_point))
else:
missing_from_bridge.append(hook_name)

Expand All @@ -152,12 +153,11 @@ def benchmark_forward_hooks(
_ = bridge(test_text)

# Clean up bridge hooks
for hook_name, handle in bridge_handles:
if handle is not None:
handle.remove()
for _, hook_point in bridge_hook_points:
hook_point.remove_hooks()

# Check for hooks that didn't fire (registered but no activation captured)
registered_hooks = {name for name, _ in bridge_handles}
registered_hooks = {name for name, _ in bridge_hook_points}
hooks_that_didnt_fire = registered_hooks - set(bridge_activations.keys())

if reference_model is None:
Expand All @@ -182,12 +182,12 @@ def benchmark_forward_hooks(
)

# Register hooks on reference model
reference_handles = []
reference_hook_points: list[HookPoint] = []
for hook_name in hook_names:
if hook_name in reference_model.hook_dict:
hook_point = reference_model.hook_dict[hook_name]
handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name)) # type: ignore[func-returns-value]
reference_handles.append(handle)
hook_point.add_hook(make_capture_hook(reference_activations, hook_name))
reference_hook_points.append(hook_point)

# Run reference forward pass
with torch.no_grad():
Expand All @@ -197,9 +197,8 @@ def benchmark_forward_hooks(
_ = reference_model(test_text)

# Clean up reference hooks
for handle in reference_handles:
if handle is not None:
handle.remove()
for hook_point in reference_hook_points:
hook_point.remove_hooks()

# CRITICAL CHECK: Bridge must have all hooks that reference has.
# Filter out hooks that bridge models inherently don't have.
Expand Down Expand Up @@ -363,7 +362,7 @@ def benchmark_gated_hooks_fire(
tested_flags.append(flag_name)
try:
activations: dict[str, torch.Tensor] = {}
handles: list[tuple[str, object]] = []
bridge_hook_points: list[HookPoint] = []
target_hook_names = [
name
for name in bridge.hook_dict
Expand All @@ -375,18 +374,17 @@ def benchmark_gated_hooks_fire(
]
for hname in target_hook_names:
hp = bridge.hook_dict[hname]
h = hp.add_hook(make_capture_hook(activations, hname)) # type: ignore[func-returns-value]
handles.append((hname, h))
hp.add_hook(make_capture_hook(activations, hname))
bridge_hook_points.append(hp)

with torch.no_grad():
if prepend_bos is not None:
_ = bridge(test_text, prepend_bos=prepend_bos)
else:
_ = bridge(test_text)

for _, h in handles:
if h is not None and hasattr(h, "remove"):
h.remove()
for hp in bridge_hook_points:
hp.remove_hooks()

# Bucket fired counts per stem.
for stem in hook_stems:
Expand Down Expand Up @@ -497,21 +495,20 @@ def benchmark_critical_forward_hooks(
bridge_activations: Dict[str, torch.Tensor] = {}

# Register hooks on bridge
bridge_handles = []
bridge_hook_points: list[HookPoint] = []
for hook_name in critical_hooks:
if hook_name in bridge.hook_dict:
hook_point = bridge.hook_dict[hook_name]
handle = hook_point.add_hook(make_capture_hook(bridge_activations, hook_name)) # type: ignore[func-returns-value]
bridge_handles.append(handle)
hook_point.add_hook(make_capture_hook(bridge_activations, hook_name))
bridge_hook_points.append(hook_point)

# Run bridge forward pass
with torch.no_grad():
_ = bridge(test_text)

# Clean up hooks
for handle in bridge_handles:
if handle is not None:
handle.remove()
for hook_point in bridge_hook_points:
hook_point.remove_hooks()

if reference_model is None:
# No reference - just verify activations were captured
Expand All @@ -526,21 +523,20 @@ def benchmark_critical_forward_hooks(
# Compare with reference model
reference_activations: Dict[str, torch.Tensor] = {}

reference_handles = []
reference_hook_points: list[HookPoint] = []
for hook_name in critical_hooks:
if hook_name in reference_model.hook_dict:
hook_point = reference_model.hook_dict[hook_name]
handle = hook_point.add_hook(make_capture_hook(reference_activations, hook_name)) # type: ignore[func-returns-value]
reference_handles.append(handle)
hook_point.add_hook(make_capture_hook(reference_activations, hook_name))
reference_hook_points.append(hook_point)

# Run reference forward pass
with torch.no_grad():
_ = reference_model(test_text)

# Clean up hooks
for handle in reference_handles:
if handle is not None:
handle.remove()
for hook_point in reference_hook_points:
hook_point.remove_hooks()

# Compare activations — categorize by presence
bridge_missing = [] # Hooks in reference but not in bridge (BAD)
Expand Down
Loading
Loading