def _make_bridge(q_lora_rank, seed=0):
    """Build a tiny DeepseekV2 bridge with the given q_lora_rank (None = V2-Lite).

    A 4-layer ``DeepseekV2ForCausalLM`` is constructed directly from a small
    config (no pretrained checkpoint is loaded), saved to a temporary directory
    together with the GPT-2 tokenizer, and re-loaded on CPU through
    ``TransformerBridge.boot_transformers``.

    Args:
        q_lora_rank: Value forwarded to ``DeepseekV2Config.q_lora_rank``. Pass
            an int for the V2-full variant or ``None`` for V2-Lite.
        seed: Seed applied to the default CPU generator while the model is
            constructed, so the weights (and therefore the tolerance-based
            comparisons in this module) are reproducible from run to run.

    Returns:
        Whatever ``TransformerBridge.boot_transformers`` returns for the saved
        model directory.
    """
    cfg = DeepseekV2Config(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=4,
        num_attention_heads=8,
        q_lora_rank=q_lora_rank,
        kv_lora_rank=32,
        qk_nope_head_dim=16,
        qk_rope_head_dim=8,
        v_head_dim=16,
        vocab_size=1000,
        first_k_dense_replace=1,
        n_routed_experts=8,
        n_shared_experts=2,
        num_experts_per_tok=2,
        max_position_embeddings=128,
        moe_intermediate_size=256,
    )
    # Seed only inside a forked CPU RNG scope: the construction becomes
    # deterministic, and the caller's RNG state is restored on exit so other
    # tests in the session are unaffected.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(seed)
        hf_model = DeepseekV2ForCausalLM(cfg)
    with tempfile.TemporaryDirectory() as tmpdir:
        hf_model.save_pretrained(tmpdir)
        tok = AutoTokenizer.from_pretrained("gpt2")
        tok.save_pretrained(tmpdir)
        # Load while the directory still exists; it is removed on exit.
        return TransformerBridge.boot_transformers(tmpdir, device="cpu")
class TestDeepSeekV2ForwardPass:
    """Forward-pass sanity checks for the V2-full bridge fixture."""

    def test_forward_returns_correct_shape(self, tiny_deepseek_v2_bridge):
        """Logits have shape (batch=1, seq=4, vocab=1000) and hold no NaN/Inf."""
        with torch.no_grad():
            logits = tiny_deepseek_v2_bridge(_tokens())
        expected_shape = (1, 4, 1000)
        assert logits.shape == expected_shape
        assert not torch.isnan(logits).any()
        assert not torch.isinf(logits).any()

    def test_forward_matches_hf(self, tiny_deepseek_v2_bridge):
        """Bridge logits stay within 0.15 (max abs diff) of the wrapped HF model."""
        token_ids = _tokens()
        reference_model = tiny_deepseek_v2_bridge.original_model
        with torch.no_grad():
            # Bridge first, then the wrapped model, on the same token ids.
            bridge_logits = tiny_deepseek_v2_bridge(token_ids)
            reference_logits = reference_model(token_ids).logits
        max_diff = torch.max(torch.abs(bridge_logits - reference_logits)).item()
        assert max_diff < 0.15, f"Bridge vs HF max diff = {max_diff}"
class TestDeepSeekV2AttentionHooks:
    """Attention hook coverage for the V2-full bridge fixture."""

    def test_attn_hooks_fire_all_layers(self, tiny_deepseek_v2_bridge):
        """Every one of the four blocks caches attn.hook_in and attn.hook_out."""
        _logits, hook_cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        for layer_idx in range(4):
            attn_prefix = f"blocks.{layer_idx}.attn"
            for hook_suffix in ("hook_in", "hook_out"):
                assert f"{attn_prefix}.{hook_suffix}" in hook_cache

    def test_mla_latent_hooks_fire(self, tiny_deepseek_v2_bridge):
        """Both the Q-latent and KV-latent hooks appear somewhere in the cache."""
        _logits, hook_cache = tiny_deepseek_v2_bridge.run_with_cache(_tokens())
        q_latent_seen = any("hook_q_latent" in hook_name for hook_name in hook_cache)
        assert q_latent_seen
        kv_latent_seen = any("hook_kv_latent" in hook_name for hook_name in hook_cache)
        assert kv_latent_seen
class TestDeepSeekV2LiteNoQLatentHook:
    """Latent-hook expectations for the V2-Lite fixture (q_lora_rank=None)."""

    def test_hook_q_latent_absent_without_q_lora_rank(self, tiny_deepseek_v2_lite_bridge):
        """No cached hook name may mention hook_q_latent when Q is uncompressed."""
        _logits, hook_cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens())
        assert all("hook_q_latent" not in hook_name for hook_name in hook_cache)

    def test_hook_kv_latent_still_fires(self, tiny_deepseek_v2_lite_bridge):
        """The KV-latent hook is still cached regardless of q_lora_rank."""
        _logits, hook_cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens())
        kv_latent_seen = any("hook_kv_latent" in hook_name for hook_name in hook_cache)
        assert kv_latent_seen

    def test_all_layers_produce_non_nan(self, tiny_deepseek_v2_lite_bridge):
        """Attention output of each of the four blocks contains no NaN."""
        _logits, hook_cache = tiny_deepseek_v2_lite_bridge.run_with_cache(_tokens())
        for layer_idx in range(4):
            attn_out = hook_cache[f"blocks.{layer_idx}.attn.hook_out"]
            assert not torch.isnan(attn_out).any()
def _rotate_by_freqs(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent (even, odd) feature pairs of ``x`` by complex ``freqs``."""
    # Pair up the last axis as (real, imag) in float32 so it can be viewed as complex.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Complex multiply == 2-D rotation; then re-interleave the pairs and restore
    # the caller's dtype. flatten(-2) merges only the trailing (pairs, 2) axes,
    # so it does not depend on the number of leading axes.
    return torch.view_as_real(x_c * freqs).flatten(-2).type_as(x)


def _apply_rotary_complex(
    q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding via complex multiplication (DeepSeek-V2 style).

    DeepSeek-V2 uses ``freqs_cis = torch.polar(ones, freqs)`` (complex exponentials)
    instead of the standard (cos, sin) pair. Each consecutive pair of features
    ``(x[2j], x[2j + 1])`` is treated as one complex number and multiplied by the
    matching entry of ``freqs_cis``.

    Args:
        q: Query rope portion [batch, heads, seq, rope_dim].
        k: Key rope portion [batch, 1, seq, rope_dim].
        freqs_cis: Complex rotary frequencies [batch, seq, rope_dim // 2].

    Returns:
        Tuple of rotated (q, k) tensors with same dtype and shape as inputs.
    """
    # Insert the head axis for broadcasting and move to the query's device once.
    freqs = freqs_cis.unsqueeze(1).to(q.device)  # [batch, 1, seq, rope_dim // 2]
    q_rot = _rotate_by_freqs(q, freqs)
    # ``Tensor.to`` returns the same tensor when the device already matches, so
    # this is only a real transfer if k lives on a different device than q.
    k_rot = _rotate_by_freqs(k, freqs.to(k.device))
    return q_rot, k_rot
) - - q_rot, k_rot = _apply_rotary_pos_emb(q_rot, k_rot, cos, sin) q_rot = self.hook_rot_q(q_rot) k_rot = self.hook_rot_k(k_rot) @@ -209,7 +245,11 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: past_key_values = kwargs.pop("past_key_values", None) cache_position = kwargs.pop("cache_position", None) if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs: dict = {"cache_position": cache_position} + if cos is not None: + cache_kwargs["cos"] = cos + if sin is not None: + cache_kwargs["sin"] = sin key_states, value_states = past_key_values.update( key_states, value_states, hf_attn.layer_idx, cache_kwargs ) diff --git a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py index 3af922a04..c560fff96 100644 --- a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py +++ b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py @@ -2,7 +2,7 @@ This module contains the bridge component for rotary position embedding layers. """ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch @@ -81,7 +81,9 @@ def get_random_inputs( args = (x, position_ids, layer_type) return {"args": args} - def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, *args: Any, **kwargs: Any + ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Forward pass through the rotary embedding bridge. Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors. @@ -94,7 +96,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor Returns: Tuple of (cos, sin) tensors for rotary position embeddings, after being - passed through hook_cos and hook_sin respectively + passed through hook_cos and hook_sin respectively. 
For DeepSeek-V2-style + embeddings that return a single complex ``freqs_cis`` tensor, that tensor is + passed through unchanged for downstream complex multiplication. """ if self.original_component is None: raise RuntimeError( @@ -109,8 +113,13 @@ def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor # Call original component to get (cos, sin) tuple output = self.original_component(*args, **kwargs) - # Ensure output is a tuple + # Ensure output is a tuple — or a complex tensor (DeepSeek-V2 freqs_cis style) if not isinstance(output, tuple): + if isinstance(output, torch.Tensor) and output.is_complex(): + # V2-style: freqs_cis complex tensor — pass through without cos/sin split. + # hook_cos/hook_sin do not apply here; the complex form is consumed by + # MLAAttentionBridge which detects it and uses complex multiplication. + return output if hasattr(output, "__iter__") and (not isinstance(output, torch.Tensor)): output = tuple(output) else: diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 84e4584af..0d76a7f76 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -21,6 +21,9 @@ from transformer_lens.model_bridge.supported_architectures.cohere import ( CohereArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.deepseek_v2 import ( + DeepSeekV2ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.deepseek_v3 import ( DeepSeekV3ArchitectureAdapter, ) @@ -182,6 +185,7 @@ "BloomArchitectureAdapter", "CodeGenArchitectureAdapter", "CohereArchitectureAdapter", + "DeepSeekV2ArchitectureAdapter", "DeepSeekV3ArchitectureAdapter", "FalconArchitectureAdapter", "Gemma1ArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/deepseek_v2.py 
class DeepSeekV2ArchitectureAdapter(ArchitectureAdapter):
    """Map DeepseekV2ForCausalLM modules (V2 / V2-Lite / Coder-V2) onto bridge components.

    The mapping covers RMSNorm layers, MLA attention with either compressed Q
    projections (q_a_proj -> q_a_layernorm -> q_b_proj) or a direct q_proj when
    q_lora_rank is None, rotary embeddings, and an MLP slot that is a MoE block
    on most layers and a dense MLP on the first few.
    """

    def __init__(self, cfg: Any) -> None:
        """Set the DeepSeek-V2 config flags and build ``component_mapping``."""
        super().__init__(cfg)

        # Architecture flags written onto the shared config object.
        for flag_name, flag_value in (
            ("normalization_type", "RMS"),
            ("positional_embedding_type", "rotary"),
            ("gated_mlp", True),
            ("final_rms", True),
            ("uses_rms_norm", True),
        ):
            setattr(self.cfg, flag_name, flag_value)

        self.weight_processing_conversions = {}

        token_embed = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)

        pre_attn_norm = RMSNormalizationBridge(name="input_layernorm", config=self.cfg)
        post_attn_norm = RMSNormalizationBridge(
            name="post_attention_layernorm", config=self.cfg
        )

        attention = MLAAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                # Compressed-Q path (q_lora_rank set). These modules do not
                # exist on V2-Lite, hence optional=True so setup can skip them;
                # MLAAttentionBridge decides at forward time which path to run.
                "q_a_proj": LinearBridge(name="q_a_proj", optional=True),
                # The attention bridge invokes this norm itself, so a plain
                # GeneralizedComponent wrapper is enough here.
                "q_a_layernorm": GeneralizedComponent(name="q_a_layernorm", optional=True),
                "q_b_proj": LinearBridge(name="q_b_proj", optional=True),
                # Direct-Q path, present only on V2-Lite.
                "q_proj": LinearBridge(name="q_proj", optional=True),
                # KV compression modules exist on every V2 variant.
                "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
                "kv_a_layernorm": RMSNormalizationBridge(
                    name="kv_a_layernorm", config=self.cfg
                ),
                "kv_b_proj": LinearBridge(name="kv_b_proj"),
                "o": LinearBridge(name="o_proj"),
            },
        )

        # Dense layers (idx < first_k_dense_replace) hold a DeepseekV2MLP with
        # no shared_experts child, so that submodule is optional.
        # The router gate is intentionally left unbridged: DeepseekV2Moe.forward()
        # applies nn.functional.linear(..., self.gate.weight) rather than calling
        # the gate module, so hooks on it would never run.
        shared_experts = GatedMLPBridge(
            name="shared_experts",
            config=self.cfg,
            optional=True,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
        feed_forward = MoEBridge(
            name="mlp",
            config=self.cfg,
            submodules={"shared_experts": shared_experts},
        )

        layers = MLABlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attn_norm,
                "ln2": post_attn_norm,
                "attn": attention,
                "mlp": feed_forward,
            },
        )

        final_norm = RMSNormalizationBridge(name="model.norm", config=self.cfg)
        lm_head = UnembeddingBridge(name="lm_head")

        self.component_mapping = {
            "embed": token_embed,
            "rotary_emb": rotary,
            "blocks": layers,
            "ln_final": final_norm,
            "unembed": lm_head,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Hand the HF rotary embedding module to the attention bridges.

        Args:
            hf_model: Model exposing ``model.rotary_emb``.
            bridge_model: Optional bridge; when it has ``blocks``, every block
                that has an ``attn`` attribute receives the rotary embedding.
        """
        rotary_emb = hf_model.model.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for layer in bridge_model.blocks:
                if not hasattr(layer, "attn"):
                    continue
                layer.attn.set_rotary_emb(rotary_emb)

        # The template attention component in the mapping also gets the reference.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(rotary_emb)