Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ def bridge_model(self, gpt2_bridge):
def sample_cache(self, bridge_model):
"""Create a sample cache for testing."""
prompt = "The quick brown fox jumps over the lazy dog."
output, cache = bridge_model.run_with_cache(prompt)
output, cache = bridge_model.run_with_cache(input=prompt)
return cache

def test_cache_creation(self, bridge_model):
"""Test that caches can be created from TransformerBridge."""
prompt = "Test cache creation."

output, cache = bridge_model.run_with_cache(prompt, return_cache_object=True)
output, cache = bridge_model.run_with_cache(input=prompt, return_cache_object=True)

assert isinstance(output, torch.Tensor)
assert isinstance(cache, (dict, ActivationCache))
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_cache_with_names_filter(self, bridge_model):
filter_names = list(hook_dict.keys())[:3]

try:
output, cache = bridge_model.run_with_cache(prompt, names_filter=filter_names)
output, cache = bridge_model.run_with_cache(input=prompt, names_filter=filter_names)

if hasattr(cache, "cache_dict"):
cache_dict = cache.cache_dict
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_cache_batch_dimension_handling(self, bridge_model):
prompts = ["First prompt for batch testing.", "Second prompt for batch testing."]

try:
output, cache = bridge_model.run_with_cache(prompts)
output, cache = bridge_model.run_with_cache(input=prompts)

if hasattr(cache, "cache_dict"):
cache_dict = cache.cache_dict
Expand All @@ -173,7 +173,7 @@ def test_cache_device_consistency(self, bridge_model):
prompt = "Test device consistency."

model_cpu = bridge_model.cpu()
output, cache = model_cpu.run_with_cache(prompt)
output, cache = model_cpu.run_with_cache(input=prompt)

if hasattr(cache, "cache_dict"):
cache_dict = cache.cache_dict
Expand All @@ -192,7 +192,7 @@ def test_cache_memory_efficiency(self, bridge_model):
initial_memory = torch.cuda.memory_allocated()

for _ in range(3):
output, cache = bridge_model.run_with_cache(prompt)
output, cache = bridge_model.run_with_cache(input=prompt)
del output, cache

import gc
Expand All @@ -209,10 +209,10 @@ def test_cache_memory_efficiency(self, bridge_model):

def test_cache_with_different_inputs(self, bridge_model):
"""Test that cache works with different input types."""
output1, cache1 = bridge_model.run_with_cache("String input test.")
output1, cache1 = bridge_model.run_with_cache(input="String input test.")

tokens = bridge_model.to_tokens("Token input test.")
output2, cache2 = bridge_model.run_with_cache(tokens)
output2, cache2 = bridge_model.run_with_cache(input=tokens)

assert isinstance(output1, torch.Tensor)
assert isinstance(output2, torch.Tensor)
Expand Down
4 changes: 2 additions & 2 deletions tests/acceptance/model_bridge/test_run_with_cache_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def capture_individual(tensor, hook):

for p in prompts:
gpt2_bridge.run_with_hooks(
p,
input=p,
fwd_hooks=[("blocks.11.hook_resid_post", capture_individual)],
)

Expand All @@ -61,7 +61,7 @@ def capture_batched(tensor, hook):
captured_batched.append(tensor[i, -1, :].detach().clone())

gpt2_bridge.run_with_hooks(
prompts,
input=prompts,
fwd_hooks=[("blocks.11.hook_resid_post", capture_batched)],
)

Expand Down
2 changes: 1 addition & 1 deletion tests/acceptance/test_hook_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def hook_fn(tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token

# Run with hooks
out_from_hook = model.run_with_hooks(
prompt,
input=prompt,
prepend_bos=False,
fwd_hooks=[("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))],
)
Expand Down
4 changes: 2 additions & 2 deletions tests/acceptance/test_hooked_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def remove_pos_embed(z, hook):
z[:] = 0.0
return z

_ = model.run_with_hooks("Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)])
_ = model.run_with_hooks(input="Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)])

# Check that pos embed has not been permanently changed
assert (model.W_pos == initial_W_pos).all()
Expand All @@ -600,7 +600,7 @@ def edit_pos_embed(z, hook):
return z

_ = model.run_with_hooks(
["Hello, world", "Goodbye, world"],
input=["Hello, world", "Goodbye, world"],
fwd_hooks=[("hook_pos_embed", edit_pos_embed)],
)

Expand Down
10 changes: 5 additions & 5 deletions tests/acceptance/test_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
gpt2_tokens = model_1_device.to_tokens(gpt2_text)

gpt2_logits_1_device, gpt2_cache_1_device = model_1_device.run_with_cache(
gpt2_tokens, remove_batch_dim=True
input=gpt2_tokens, remove_batch_dim=True
)
gpt2_logits_n_devices, gpt2_cache_n_devices = model_n_devices.run_with_cache(
gpt2_tokens, remove_batch_dim=True
input=gpt2_tokens, remove_batch_dim=True
)

# Make sure the tensors in cache remain on their respective devices
Expand Down Expand Up @@ -106,16 +106,16 @@ def test_load_model_on_target_device():
def test_cache_device():
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1")

logits, cache = model.run_with_cache("Hello there")
logits, cache = model.run_with_cache(input="Hello there")
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(
torch.device("cuda:1")
)

logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
logits, cache = model.run_with_cache(input="Hello there", device=torch.device("cpu"))
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))

model.to("cuda")
logits, cache = model.run_with_cache("Hello there")
logits, cache = model.run_with_cache(input="Hello there")
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(logits.device)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,20 @@ class TestCacheBasics:
def test_run_with_cache_returns_nonempty(self, bridge_compat):
"""run_with_cache returns a non-empty cache."""
with torch.no_grad():
_, cache = bridge_compat.run_with_cache("Hello world")
_, cache = bridge_compat.run_with_cache(input="Hello world")
assert len(cache) > 0

def test_cache_contains_residual_hooks(self, bridge_compat):
"""Cache should contain residual stream hooks."""
with torch.no_grad():
_, cache = bridge_compat.run_with_cache("Hello world")
_, cache = bridge_compat.run_with_cache(input="Hello world")
cache_keys = list(cache.keys())
assert any("hook_resid" in k for k in cache_keys)

def test_cache_values_are_tensors(self, bridge_compat):
"""All cached values should be tensors with correct batch dimension."""
with torch.no_grad():
_, cache = bridge_compat.run_with_cache("Hello")
_, cache = bridge_compat.run_with_cache(input="Hello")
for key, value in cache.items():
assert isinstance(value, torch.Tensor), f"Cache[{key}] is {type(value)}"
assert value.shape[0] == 1, f"Cache[{key}] batch dim is {value.shape[0]}"
Expand All @@ -81,9 +81,9 @@ class TestCacheNamesFilter:
def test_names_filter_returns_subset(self, bridge_compat):
"""names_filter should return only matching keys."""
with torch.no_grad():
_, full_cache = bridge_compat.run_with_cache("Hello")
_, full_cache = bridge_compat.run_with_cache(input="Hello")
_, filtered_cache = bridge_compat.run_with_cache(
"Hello",
input="Hello",
names_filter=lambda name: "hook_resid_pre" in name,
)

Expand All @@ -98,7 +98,7 @@ class TestCacheCompleteness:

def test_all_expected_hooks_in_cache(self, bridge_compat):
"""Cache should contain all expected hook names."""
_, cache = bridge_compat.run_with_cache("Hello World!")
_, cache = bridge_compat.run_with_cache(input="Hello World!")
actual_keys = set(cache.keys())
missing = set(EXPECTED_HOOKS) - actual_keys
assert len(missing) == 0, f"Missing expected hooks: {sorted(missing)}"
Expand Down Expand Up @@ -133,8 +133,8 @@ def test_cache_values_match(self, bridge_compat, reference_ht):
Unmasked scores and resulting patterns should still match.
"""
prompt = "Hello World!"
_, bridge_cache = bridge_compat.run_with_cache(prompt)
_, ht_cache = reference_ht.run_with_cache(prompt)
_, bridge_cache = bridge_compat.run_with_cache(input=prompt)
_, ht_cache = reference_ht.run_with_cache(input=prompt)

for hook in EXPECTED_HOOKS:
if hook not in bridge_cache or hook not in ht_cache:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def hook_fn(tensor, hook):
return tensor

bridge.run_with_hooks(
"Hello world",
input="Hello world",
fwd_hooks=[("blocks.0.hook_resid_pre", hook_fn)],
)
assert count == 1
Expand All @@ -58,7 +58,7 @@ def hook_fn(tensor, hook):
return tensor

bridge.run_with_hooks(
"Hello",
input="Hello",
fwd_hooks=[("blocks.0.hook_resid_pre", hook_fn)],
)
assert len(captured["shape"]) >= 2
Expand All @@ -76,7 +76,7 @@ def hook_fn(tensor, hook):
return hook_fn

bridge.run_with_hooks(
"Hello",
input="Hello",
fwd_hooks=[
("blocks.0.hook_resid_pre", make_hook("resid_pre_0")),
("blocks.0.hook_resid_post", make_hook("resid_post_0")),
Expand Down Expand Up @@ -116,7 +116,7 @@ def zero_hook(tensor, hook):
return torch.zeros_like(tensor)

modified_output = bridge.run_with_hooks(
"Hello world",
input="Hello world",
fwd_hooks=[("blocks.0.hook_resid_pre", zero_hook)],
)

Expand All @@ -132,7 +132,7 @@ def ablation_hook(activation, hook):
return activation

ablated_loss = bridge_compat.run_with_hooks(
test_text,
input=test_text,
return_type="loss",
fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)],
)
Expand All @@ -154,14 +154,14 @@ def ablation_hook(activation, hook):

ht_baseline = reference_ht(test_text, return_type="loss")
ht_ablated = reference_ht.run_with_hooks(
test_text,
input=test_text,
return_type="loss",
fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)],
)

bridge_baseline = bridge_compat(test_text, return_type="loss")
bridge_ablated = bridge_compat.run_with_hooks(
test_text,
input=test_text,
return_type="loss",
fwd_hooks=[("blocks.0.attn.hook_v", ablation_hook)],
)
Expand Down Expand Up @@ -190,7 +190,7 @@ def hook_fn(activation, hook):
return hook_fn

bridge_compat.run_with_hooks(
"The quick brown fox",
input="The quick brown fox",
return_type="logits",
fwd_hooks=[("hook_embed", capture("embed"))],
)
Expand All @@ -209,7 +209,7 @@ def hook_fn(activation, hook):
return hook_fn

bridge_compat.run_with_hooks(
"The quick brown fox",
input="The quick brown fox",
return_type="logits",
fwd_hooks=[("blocks.0.attn.hook_v", capture("v"))],
)
Expand Down Expand Up @@ -259,7 +259,7 @@ def hook_fn(tensor, hook):

with torch.no_grad():
bridge.run_with_hooks(
"Hello",
input="Hello",
fwd_hooks=[("blocks.0.hook_resid_pre", hook_fn)],
)
assert count == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def v_ablation_hook(value, hook):
original_loss = model(tokens, return_type="loss")
# Use the correct hook name for Bridge architecture (v.hook_out instead of hook_v)
hooked_loss = model.run_with_hooks(
tokens,
input=tokens,
return_type="loss",
fwd_hooks=[("blocks.0.attn.v.hook_out", v_ablation_hook)],
)
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/model_bridge/test_bridge_stop_at_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def count_hook(activation, hook):
# Hook at blocks.0 should fire
# Hook at blocks.1 should NOT fire (stop_at_layer=1)
output = bridge_default.run_with_hooks(
rand_input,
input=rand_input,
stop_at_layer=1,
fwd_hooks=[
("embed.hook_out", count_hook),
Expand Down Expand Up @@ -505,7 +505,7 @@ def count_hook(activation, hook):
# Hook at blocks.0 should fire
# Hook at blocks.1 should NOT fire (stop_at_layer=1)
output = bridge_with_compat_and_processing.run_with_hooks(
rand_input,
input=rand_input,
stop_at_layer=1,
fwd_hooks=[
("embed.hook_out", count_hook),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _h(_m, _i, o):
for i in range(n_layers)
]
with torch.inference_mode():
bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks)
bridge.run_with_hooks(input=tokens, fwd_hooks=fwd_hooks)

for i in range(n_layers):
d = (hf_layer_out[i] - bridge_layer_out[i]).abs().max().item()
Expand All @@ -114,7 +114,7 @@ def test_bridge_attention_reconstruction_actually_runs(bridge, tokenize):
tokens = tokenize("Hello, world!")
attn_scores_fired: list[bool] = []
bridge.run_with_hooks(
tokens,
input=tokens,
fwd_hooks=[
("blocks.0.attn.hook_attn_scores", lambda v, hook: attn_scores_fired.append(True)),
],
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/model_bridge/test_mamba_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def zero_resid(t, hook):

with torch.no_grad():
ablated = mamba_bridge.run_with_hooks(
tokens,
input=tokens,
fwd_hooks=[("blocks.12.hook_in", zero_resid)],
)
# Zeroing a mid-layer residual stream should change the output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def fn(tensor, hook):

with torch.no_grad():
bridge.run_with_hooks(
"The quick brown fox", fwd_hooks=[(n, capture(n)) for n in gate_hooks]
input="The quick brown fox", fwd_hooks=[(n, capture(n)) for n in gate_hooks]
)

gate = captured.get("blocks.1.attn.hook_q_gate")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/model_bridge/test_smollm3_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _hook(_module, _inputs, output):
for i in range(n_layers)
]
with torch.inference_mode():
bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks)
bridge.run_with_hooks(input=tokens, fwd_hooks=fwd_hooks)

for i, layer in enumerate(hf_eager.model.layers):
drift = (hf_layer_out[i] - bridge_layer_out[i]).abs().max().item()
Expand All @@ -163,7 +163,7 @@ def test_bridge_runs_its_own_attention_reconstruction(
"""
fired: list[bool] = []
bridge.run_with_hooks(
tokens,
input=tokens,
fwd_hooks=[
("blocks.0.attn.hook_attn_scores", lambda value, hook: fired.append(True)),
],
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/model_bridge/test_weight_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def ablation_hook(activation, hook):
return activation

ref_ablated_loss = reference_ht.run_with_hooks(
test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)]
input=test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)]
)
ref_ablation_effect = ref_ablated_loss - ref_loss

Expand All @@ -115,7 +115,7 @@ def ablation_hook(activation, hook):

# Test ablation with bridge
bridge_ablated_loss = bridge.run_with_hooks(
test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)]
input=test_text, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)]
)
bridge_ablation_effect = bridge_ablated_loss - bridge_loss

Expand Down Expand Up @@ -193,7 +193,7 @@ def ablation_hook(
hook_name = utils.get_act_name("v", layer)
orig = model(tokens, return_type="loss").item()
ablated = model.run_with_hooks(
tokens, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)]
input=tokens, return_type="loss", fwd_hooks=[(hook_name, ablation_hook)]
).item()
return orig, ablated

Expand Down
Loading
Loading