Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions tests/integration/model_bridge/test_deepseek_v2_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Integration tests for DeepSeek V2 architecture adapter.

Covers two distinct variants of DeepseekV2ForCausalLM:
- V2-full (q_lora_rank set): Q is compressed via two-stage LoRA projection.
- V2-Lite (q_lora_rank=None): Q uses a direct linear projection; no compression.
"""

import tempfile

import pytest
import torch
from transformers import AutoTokenizer, DeepseekV2Config, DeepseekV2ForCausalLM

from transformer_lens.model_bridge.bridge import TransformerBridge


def _make_bridge(q_lora_rank):
    """Build a tiny DeepseekV2 bridge with the given q_lora_rank (None = V2-Lite).

    A small ``DeepseekV2ForCausalLM`` is instantiated from an in-memory config,
    saved next to a GPT-2 tokenizer in a temporary directory, and then loaded
    back through ``TransformerBridge.boot_transformers`` on CPU.

    Args:
        q_lora_rank: Forwarded to ``DeepseekV2Config.q_lora_rank``. ``None``
            selects the V2-Lite variant; an integer selects V2-full.

    Returns:
        The object returned by ``TransformerBridge.boot_transformers``.
    """
    cfg = DeepseekV2Config(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=4,
        num_attention_heads=8,
        q_lora_rank=q_lora_rank,
        kv_lora_rank=32,
        qk_nope_head_dim=16,
        qk_rope_head_dim=8,
        v_head_dim=16,
        vocab_size=1000,
        first_k_dense_replace=1,
        n_routed_experts=8,
        n_shared_experts=2,
        num_experts_per_tok=2,
        max_position_embeddings=128,
        moe_intermediate_size=256,
    )
    # Seed the RNG so the randomly initialised weights are identical on every
    # run; otherwise the numeric tolerance checks in this module would be
    # evaluated on different weights each time and failures could not be
    # reproduced.
    torch.manual_seed(0)
    hf_model = DeepseekV2ForCausalLM(cfg)
    with tempfile.TemporaryDirectory() as tmpdir:
        hf_model.save_pretrained(tmpdir)
        tok = AutoTokenizer.from_pretrained("gpt2")
        tok.save_pretrained(tmpdir)
        # Load before leaving the context manager: the directory is removed on exit.
        return TransformerBridge.boot_transformers(tmpdir, device="cpu")


@pytest.fixture(scope="module")
def tiny_deepseek_v2_bridge():
    """Module-scoped V2-full bridge: q_lora_rank=64, two-stage Q compression (same as V3)."""
    bridge = _make_bridge(q_lora_rank=64)
    return bridge


@pytest.fixture(scope="module")
def tiny_deepseek_v2_lite_bridge():
    """Module-scoped V2-Lite bridge: q_lora_rank=None, direct Q projection without LoRA."""
    bridge = _make_bridge(q_lora_rank=None)
    return bridge


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def _tokens():
return torch.tensor([[1, 2, 3, 4]])


# ---------------------------------------------------------------------------
# V2-full tests
# ---------------------------------------------------------------------------


class TestDeepSeekV2BridgeCreation:
    """Structural checks on the V2-full bridge."""

    def test_block_count(self, tiny_deepseek_v2_bridge):
        """The bridge exposes four blocks."""
        n_blocks = len(tiny_deepseek_v2_bridge.blocks)
        assert n_blocks == 4

    def test_has_embed_unembed_ln_final(self, tiny_deepseek_v2_bridge):
        """The embed, unembed and final-norm attributes are all present."""
        for attr_name in ("embed", "unembed", "ln_final"):
            assert hasattr(tiny_deepseek_v2_bridge, attr_name)

    def test_attention_is_mla(self, tiny_deepseek_v2_bridge):
        """The first block's attention component is an ``MLAAttentionBridge``."""
        from transformer_lens.model_bridge.generalized_components.mla_attention import (
            MLAAttentionBridge,
        )

        first_attn = tiny_deepseek_v2_bridge.blocks[0].attn
        assert isinstance(first_attn, MLAAttentionBridge)


class TestDeepSeekV2ForwardPass:
    """Forward-pass checks on the V2-full bridge."""

    def test_forward_returns_correct_shape(self, tiny_deepseek_v2_bridge):
        """Logits are [1, 4, 1000] and contain neither NaN nor inf."""
        token_ids = _tokens()
        with torch.no_grad():
            logits = tiny_deepseek_v2_bridge(token_ids)
        assert logits.shape == (1, 4, 1000)
        assert not torch.isnan(logits).any()
        assert not torch.isinf(logits).any()

    def test_forward_matches_hf(self, tiny_deepseek_v2_bridge):
        """Bridge logits stay within 0.15 (max abs diff) of the wrapped HF model."""
        token_ids = _tokens()
        reference_model = tiny_deepseek_v2_bridge.original_model
        with torch.no_grad():
            bridge_logits = tiny_deepseek_v2_bridge(token_ids)
            reference_logits = reference_model(token_ids).logits
        abs_diff = (bridge_logits - reference_logits).abs()
        max_diff = abs_diff.max().item()
        assert max_diff < 0.15, f"Bridge vs HF max diff = {max_diff}"


class TestDeepSeekV2DenseVsMoELayers:
    """Hook coverage for block 0 versus block 1, plus MLP hooks on all four blocks."""

    def test_dense_layer_has_no_moe_hooks(self, tiny_deepseek_v2_bridge):
        """Block 0 caches neither gate nor shared-expert hooks."""
        _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        assert all("blocks.0.mlp.gate" not in name for name in cache)
        assert all("blocks.0.mlp.shared_experts" not in name for name in cache)

    def test_moe_layer_has_shared_expert_hooks(self, tiny_deepseek_v2_bridge):
        """Block 1 caches shared-expert hooks but no gate hooks."""
        # DeepseekV2Moe.forward() routes through
        # nn.functional.linear(..., self.gate.weight) rather than calling
        # self.gate(hidden_states), so the gate module's forward() never runs and
        # its bridge hooks cannot fire. shared_experts is invoked via __call__,
        # so the GatedMLPBridge hooks do fire.
        _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        gate_hook_seen = any("blocks.1.mlp.gate" in name for name in cache)
        assert (
            not gate_hook_seen
        ), "gate hooks should not appear — gate is called via functional.linear, not forward()"
        assert any("blocks.1.mlp.shared_experts" in name for name in cache)

    def test_all_layers_have_mlp_hooks(self, tiny_deepseek_v2_bridge):
        """Each of the four blocks caches MLP in/out hooks with a NaN-free output."""
        _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        for i in range(4):
            hook_in_key = f"blocks.{i}.mlp.hook_in"
            hook_out_key = f"blocks.{i}.mlp.hook_out"
            assert hook_in_key in cache
            assert hook_out_key in cache
            assert not torch.isnan(cache[hook_out_key]).any()


class TestDeepSeekV2AttentionHooks:
    """Attention hook coverage on the V2-full bridge."""

    def test_attn_hooks_fire_all_layers(self, tiny_deepseek_v2_bridge):
        """Each of the four blocks caches attention in/out hooks."""
        _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        for i in range(4):
            hook_in_key = f"blocks.{i}.attn.hook_in"
            hook_out_key = f"blocks.{i}.attn.hook_out"
            assert hook_in_key in cache
            assert hook_out_key in cache

    def test_mla_latent_hooks_fire(self, tiny_deepseek_v2_bridge):
        """At least one cache key mentions each of the Q and KV latent hooks."""
        _, cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        q_latent_seen = any("hook_q_latent" in name for name in cache)
        assert q_latent_seen
        kv_latent_seen = any("hook_kv_latent" in name for name in cache)
        assert kv_latent_seen


# ---------------------------------------------------------------------------
# V2-Lite tests (q_lora_rank=None — direct q_proj, no compression)
# ---------------------------------------------------------------------------


class TestDeepSeekV2LiteBridgeCreation:
    """Structural checks on the V2-Lite bridge."""

    def test_block_count(self, tiny_deepseek_v2_lite_bridge):
        """The bridge exposes four blocks."""
        n_blocks = len(tiny_deepseek_v2_lite_bridge.blocks)
        assert n_blocks == 4

    def test_attention_is_mla(self, tiny_deepseek_v2_lite_bridge):
        """The first block's attention component is an ``MLAAttentionBridge``."""
        from transformer_lens.model_bridge.generalized_components.mla_attention import (
            MLAAttentionBridge,
        )

        first_attn = tiny_deepseek_v2_lite_bridge.blocks[0].attn
        assert isinstance(first_attn, MLAAttentionBridge)


class TestDeepSeekV2LiteForwardPass:
    """Forward-pass checks on the V2-Lite bridge."""

    def test_forward_returns_correct_shape(self, tiny_deepseek_v2_lite_bridge):
        """Logits are [1, 4, 1000] and contain neither NaN nor inf."""
        token_ids = _tokens()
        with torch.no_grad():
            logits = tiny_deepseek_v2_lite_bridge(token_ids)
        assert logits.shape == (1, 4, 1000)
        assert not torch.isnan(logits).any()
        assert not torch.isinf(logits).any()

    def test_forward_matches_hf(self, tiny_deepseek_v2_lite_bridge):
        """Bridge logits stay within 0.15 (max abs diff) of the wrapped HF model."""
        token_ids = _tokens()
        reference_model = tiny_deepseek_v2_lite_bridge.original_model
        with torch.no_grad():
            bridge_logits = tiny_deepseek_v2_lite_bridge(token_ids)
            reference_logits = reference_model(token_ids).logits
        abs_diff = (bridge_logits - reference_logits).abs()
        max_diff = abs_diff.max().item()
        assert max_diff < 0.15, f"V2-Lite bridge vs HF max diff = {max_diff}"


class TestDeepSeekV2LiteNoQLatentHook:
    """Latent-hook behaviour of the V2-Lite bridge (q_lora_rank=None)."""

    def test_hook_q_latent_absent_without_q_lora_rank(self, tiny_deepseek_v2_lite_bridge):
        """Without Q compression, no cache key mentions hook_q_latent."""
        _, cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens())
        assert all("hook_q_latent" not in name for name in cache)

    def test_hook_kv_latent_still_fires(self, tiny_deepseek_v2_lite_bridge):
        """KV compression does not depend on q_lora_rank, so hook_kv_latent is cached."""
        _, cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens())
        kv_latent_seen = any("hook_kv_latent" in name for name in cache)
        assert kv_latent_seen

    def test_all_layers_produce_non_nan(self, tiny_deepseek_v2_lite_bridge):
        """The attention output of each of the four blocks is NaN-free."""
        _, cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens())
        for i in range(4):
            attn_out = cache[f"blocks.{i}.attn.hook_out"]
            assert not torch.isnan(attn_out).any()
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BloomArchitectureAdapter,
CodeGenArchitectureAdapter,
CohereArchitectureAdapter,
DeepSeekV2ArchitectureAdapter,
DeepSeekV3ArchitectureAdapter,
FalconArchitectureAdapter,
Gemma1ArchitectureAdapter,
Expand Down Expand Up @@ -78,6 +79,7 @@
"BloomForCausalLM": BloomArchitectureAdapter,
"CodeGenForCausalLM": CodeGenArchitectureAdapter,
"CohereForCausalLM": CohereArchitectureAdapter,
"DeepseekV2ForCausalLM": DeepSeekV2ArchitectureAdapter,
"DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
"FalconForCausalLM": FalconArchitectureAdapter,
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ def _apply_rotary_pos_emb(
return q_embed, k_embed


def _apply_rotary_complex(
q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary position embedding via complex multiplication (DeepSeek-V2 style).

DeepSeek-V2 uses ``freqs_cis = torch.polar(ones, freqs)`` (complex exponentials)
instead of the standard (cos, sin) pair. This matches the V2 HF implementation of
``apply_rotary_emb``.

Args:
q: Query rope portion [batch, heads, seq, rope_dim].
k: Key rope portion [batch, 1, seq, rope_dim].
freqs_cis: Complex rotary frequencies [batch, seq, rope_dim // 2].

Returns:
Tuple of rotated (q, k) tensors with same dtype and shape as inputs.
"""
freqs = freqs_cis.unsqueeze(1) # [batch, 1, seq, rope_dim // 2]
q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
q_rot = torch.view_as_real(q_c * freqs.to(q_c.device)).flatten(3).type_as(q)
k_rot = torch.view_as_real(k_c * freqs.to(k_c.device)).flatten(3).type_as(k)
return q_rot, k_rot


class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge):
"""Bridge for DeepSeek's Multi-Head Latent Attention (MLA).

Expand Down Expand Up @@ -176,20 +201,31 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
k_rot = k_rot.view(batch_size, 1, seq_length, self._qk_rope_head_dim)

# --- RoPE ---
# DeepSeek-V2 passes a complex freqs_cis tensor; V3 passes a (cos, sin) tuple.
# Detect the format and apply the appropriate rotation.
cos = sin = None
if position_embeddings is not None:
position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
cos, sin = position_embeddings
if isinstance(position_embeddings, torch.Tensor) and position_embeddings.is_complex():
# V2-style: complex exponential freqs_cis
q_rot, k_rot = _apply_rotary_complex(q_rot, k_rot, position_embeddings)
else:
cos, sin = position_embeddings
q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
elif self._rotary_emb is not None:
# Fallback: compute from rotary_emb if position_embeddings not passed
position_ids = torch.arange(seq_length, device=hidden_states.device).unsqueeze(0)
cos, sin = self._rotary_emb(hidden_states, position_ids)
emb = self._rotary_emb(hidden_states, position_ids)
if isinstance(emb, torch.Tensor) and emb.is_complex():
q_rot, k_rot = _apply_rotary_complex(q_rot, k_rot, emb)
else:
cos, sin = emb
q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
else:
raise ValueError(
"MLAAttentionBridge requires position_embeddings or set_rotary_emb() "
"to be called before forward."
)

q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
q_rot = self.hook_rot_q(q_rot)
k_rot = self.hook_rot_k(k_rot)

Expand All @@ -209,7 +245,11 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
past_key_values = kwargs.pop("past_key_values", None)
cache_position = kwargs.pop("cache_position", None)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs: dict = {"cache_position": cache_position}
if cos is not None:
cache_kwargs["cos"] = cos
if sin is not None:
cache_kwargs["sin"] = sin
key_states, value_states = past_key_values.update(
key_states, value_states, hf_attn.layer_idx, cache_kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This module contains the bridge component for rotary position embedding layers.
"""
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -81,7 +81,9 @@ def get_random_inputs(
args = (x, position_ids, layer_type)
return {"args": args}

def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self, *args: Any, **kwargs: Any
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Forward pass through the rotary embedding bridge.

Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors.
Expand All @@ -94,7 +96,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor

Returns:
Tuple of (cos, sin) tensors for rotary position embeddings, after being
passed through hook_cos and hook_sin respectively
passed through hook_cos and hook_sin respectively. For DeepSeek-V2-style
embeddings that return a single complex ``freqs_cis`` tensor, that tensor is
passed through unchanged for downstream complex multiplication.
"""
if self.original_component is None:
raise RuntimeError(
Expand All @@ -109,8 +113,13 @@ def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor
# Call original component to get (cos, sin) tuple
output = self.original_component(*args, **kwargs)

# Ensure output is a tuple
# Ensure output is a tuple — or a complex tensor (DeepSeek-V2 freqs_cis style)
if not isinstance(output, tuple):
if isinstance(output, torch.Tensor) and output.is_complex():
# V2-style: freqs_cis complex tensor — pass through without cos/sin split.
# hook_cos/hook_sin do not apply here; the complex form is consumed by
# MLAAttentionBridge which detects it and uses complex multiplication.
return output
if hasattr(output, "__iter__") and (not isinstance(output, torch.Tensor)):
output = tuple(output)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from transformer_lens.model_bridge.supported_architectures.cohere import (
CohereArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.deepseek_v2 import (
DeepSeekV2ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.deepseek_v3 import (
DeepSeekV3ArchitectureAdapter,
)
Expand Down Expand Up @@ -182,6 +185,7 @@
"BloomArchitectureAdapter",
"CodeGenArchitectureAdapter",
"CohereArchitectureAdapter",
"DeepSeekV2ArchitectureAdapter",
"DeepSeekV3ArchitectureAdapter",
"FalconArchitectureAdapter",
"Gemma1ArchitectureAdapter",
Expand Down
Loading
Loading