diff --git a/apriel2-vllm-plugin/pyproject.toml b/apriel2-vllm-plugin/pyproject.toml new file mode 100644 index 000000000..e6a4dc07e --- /dev/null +++ b/apriel2-vllm-plugin/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.build_meta" + +[project] +name = "apriel2-vllm-plugin" +version = "0.1.0" +description = "Standalone vLLM plugin for Apriel2 models (extracted from Fast-LLM)" +requires-python = ">=3.12" +dependencies = [ + "torch", + "transformers", + "einops", +] + +[project.entry-points."vllm.general_plugins"] +apriel2 = "fast_llm_external_models.apriel2.vllm.config_convertor:register" + +[tool.setuptools.packages.find] +where = [".."] +include = [ + "fast_llm_external_models", + "fast_llm_external_models.apriel2", + "fast_llm_external_models.apriel2.vllm", +] + +[tool.setuptools.package-dir] +"" = ".." diff --git a/fast_llm_external_models/apriel2/vllm/config_convertor.py b/fast_llm_external_models/apriel2/vllm/config_convertor.py index 02bbecb7a..86c890559 100644 --- a/fast_llm_external_models/apriel2/vllm/config_convertor.py +++ b/fast_llm_external_models/apriel2/vllm/config_convertor.py @@ -9,6 +9,12 @@ """ from vllm import ModelRegistry +from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, + HybridAttentionMambaModelConfig, + MambaModelConfig, + VerifyAndUpdateConfig, +) from vllm.transformers_utils.model_arch_config_convertor import ( MODEL_ARCH_CONFIG_CONVERTORS, ModelArchConfigConvertorBase, @@ -66,6 +72,39 @@ def get_head_size(self) -> int: return self._get_first_attention_block().get("head_size", 0) +class Apriel2ModelConfig(VerifyAndUpdateConfig): + """Config handler for Apriel2 models with heterogeneous mixer types. + + Apriel2 can be pure-attention, pure-mamba, or hybrid (attention + mamba) + depending on the decoder config. vLLM's default ``is_hybrid`` dispatch + calls ``HybridAttentionMambaModelConfig`` which crashes for pure-mamba + models (``ZeroDivisionError`` when ``num_kv_heads=0``). + + This handler inspects ``layers_block_type`` on the HF config to determine + the model composition and routes to the correct config handler. + """ + + @staticmethod + def verify_and_update_config(vllm_config) -> None: + hf_config = vllm_config.model_config.hf_config + layer_types = getattr(hf_config, "layers_block_type", None) + + if layer_types is None: + # Fallback: no layer type info — assume standard transformer. + return + + has_attention = any(t == "attention" for t in layer_types) + has_mamba = any(t == "mamba" for t in layer_types) + + if has_attention and has_mamba: + # Hybrid: attention + mamba page size alignment required. + HybridAttentionMambaModelConfig.verify_and_update_config(vllm_config) + elif has_mamba: + # Pure mamba: enable FULL_AND_PIECEWISE, set mamba_block_size. + MambaModelConfig.verify_and_update_config(vllm_config) + # Pure attention: no special config needed. + + def register(): """Register Apriel2 models and config convertors with vLLM. @@ -126,7 +165,7 @@ def register(): # Best-effort only; vLLM can still proceed with the generic config. pass - # Register model class + # Register model class and config handler. # Note: some exported checkpoints may list "Apriel2ForConditionalGeneration" # in config.json's "architectures". vLLM's model selection is driven by that # field, so we alias it to the same vLLM implementation for text-only usage. @@ -135,3 +174,8 @@ def register(): arch, "fast_llm_external_models.apriel2.vllm:Apriel2ForCausalLM", ) + # Register in MODELS_CONFIG_MAP so vLLM calls our handler instead of + # relying on the is_hybrid class attribute dispatch (which can't handle + # models that are sometimes hybrid, sometimes pure-mamba). + if arch not in MODELS_CONFIG_MAP: + MODELS_CONFIG_MAP[arch] = Apriel2ModelConfig diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 5ab525f61..6049f913a 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -16,6 +16,7 @@ import torch import triton +import vllm.compilation.monitor as _compile_monitor import vllm.model_executor.layers.kda # noqa: F401 from einops import rearrange from torch import nn @@ -88,19 +89,72 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i apriel2_logger = init_logger(__name__) -# ============================================================================= -# Debug Flags -# ============================================================================= -# Top-level debug flags that control all debug output in the module. -# Set these to True to enable debugging for specific components. -DEBUG_GDN_LAYER = False # Debug GDN layer forward pass (tensors, shapes) -DEBUG_GDN_STATE = False # Debug GDN recurrent state during decode -DEBUG_GDN_OUTPUT = False # Debug GDN output hidden states during decode -DEBUG_KDA_LAYER = False # Debug KDA layer outputs -DEBUG_DECODER_LAYER = False # Debug decoder layer outputs (residual, norm) -DEBUG_FINAL_NORM = False # Debug final norm before LM head -DEBUG_LM_HEAD = False # Debug LM head input/output +# --------------------------------------------------------------------------- +# Monkey-patch: fix vLLM's KV cache group_size for heterogeneous hybrid models +# --------------------------------------------------------------------------- +# vLLM's _get_kv_cache_groups_uniform_page_size computes +# group_size = min(num_layers_per_type) +# which was designed for models with simple n:1 repeating patterns (Gemma3, +# LLaMA4). For Apriel2 supernets with arbitrary placement ratios (e.g. 12 +# attention + 1 GDN + 11 KDA), min = 1 creates one KV-cache group per layer +# (24 total). Each group triggers a metadata rebuild per forward step, +# causing >2× throughput regression on small models. +# +# The fix: use max(num_layers_per_type) as group_size. This gives at most +# num_types groups (one per spec type). Tensor positions without a layer of +# a given type are simply unused by that type — no memory is wasted because +# the physical tensor at each position is shared across types. +def _patch_kv_cache_grouping() -> None: + import vllm.v1.core.kv_cache_utils as _kcu + + _original = _kcu._get_kv_cache_groups_uniform_page_size + + def _patched(kv_cache_spec: dict) -> list: + from collections import defaultdict + from math import ceil as _ceil + + same_type_layers: defaultdict[object, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec].append(layer_name) + + min_num = min(len(v) for v in same_type_layers.values()) + max_num = max(len(v) for v in same_type_layers.values()) + + if min_num > 2 or max_num < min_num * 1.25: + # Not a singleton/near-singleton case — use original logic. + return _original(kv_cache_spec) + + # Singleton / near-singleton type detected. + # Use max_num as group_size to produce at most num_types groups. + group_size = max_num + apriel2_logger.info( + "KV cache grouping: using group_size=%d (max) instead of %d " + "(min) to avoid O(num_layers) groups with %d spec types", + group_size, + min_num, + len(same_type_layers), + ) + + grouped_layers = [] + for layers in same_type_layers.values(): + num_padding = group_size - len(layers) % group_size + if num_padding != group_size: + apriel2_logger.info( + "KV cache grouping: %d padding layers for type with %d real layers (group_size=%d)", + num_padding, + len(layers), + group_size, + ) + num_groups = _ceil(len(layers) / group_size) + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) + return _kcu.create_kv_cache_group_specs(kv_cache_spec, grouped_layers) + + _kcu._get_kv_cache_groups_uniform_page_size = _patched + + +_patch_kv_cache_grouping() # ============================================================================= @@ -757,13 +811,13 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: dtype=kv_cache_dtype, sliding_window=self.window_size, ) - else: - return FullAttentionSpec( - block_size=block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_dim, - dtype=kv_cache_dtype, - ) + + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + dtype=kv_cache_dtype, + ) class Apriel2MambaMixer(nn.Module): @@ -906,17 +960,25 @@ def stochastic_mixer_dispatch_fake( fake_impl=stochastic_mixer_dispatch_fake, ) -# Add stochastic_mixer_dispatch to vLLM's splitting ops so it causes graph breaks -# This allows dynamic dispatch at runtime even with CUDA graphs enabled +# Register custom ops as PIECEWISE splitting ops so they cause graph breaks. +# The Apriel2 GDN op MUST always be registered: even in FULL_AND_PIECEWISE mode, +# PIECEWISE graphs handle prefill. Without the graph break, the decode path gets +# baked into a compiled piece; prefill batches then replay the wrong path → +# illegal memory access. try: from vllm.config.compilation import CompilationConfig, CUDAGraphMode + _gdn_op = "vllm::apriel2_gdn_attention_core" + if _gdn_op not in CompilationConfig._attention_ops: + CompilationConfig._attention_ops.append(_gdn_op) + logger.info(f"Added {_gdn_op} to vLLM splitting ops") + _stochastic_op = "vllm::stochastic_mixer_dispatch" if _stochastic_op not in CompilationConfig._attention_ops: CompilationConfig._attention_ops.append(_stochastic_op) logger.info(f"Added {_stochastic_op} to vLLM splitting ops") except ImportError: - logger.warning("Could not add stochastic_mixer_dispatch to vLLM splitting ops") + logger.warning("Could not add custom ops to vLLM splitting ops") def _force_piecewise_cudagraph_for_stochastic_mixers(): @@ -966,7 +1028,7 @@ def fused_gdn_gating_kernel( g_ptr, beta_ptr, num_heads: tl.constexpr, - total_elements: tl.constexpr, + total_elements, BLOCK_SIZE: tl.constexpr, SOFTPLUS_THRESHOLD: tl.constexpr, ): @@ -1271,38 +1333,6 @@ def rearrange_mixed_qkv(self, mixed_qkv: torch.Tensor | None): value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query.contiguous(), key.contiguous(), value.contiguous() - def _debug_state_stats(self, name: str, state: torch.Tensor, seq_len: int): - """Debug recurrent state with statistics.""" - if not DEBUG_GDN_STATE or state is None: - return - flat = state.flatten() - first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) - print( - f"[vLLM-GDN {self.prefix}] {name} (seq_len={seq_len}): shape={state.shape}, " - f"mean={state.float().mean().item():.6f}, std={state.float().std().item():.6f}, " - f"min={state.float().min().item():.6f}, max={state.float().max().item():.6f}, " - f"first8=[{first8}]" - ) - - def _debug_tensor(self, name: str, t: torch.Tensor): - if not DEBUG_GDN_LAYER: - return - if t is None: - print(f"[GDN {self.prefix}] {name}: None") - return - flat = t.flatten()[:8] - vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) - print( - f"[GDN {self.prefix}] {name}: shape={t.shape}, dtype={t.dtype}, " - f"mean={t.float().mean().item():.6f}, std={t.float().std().item():.6f}, " - f"first8=[{vals}]" - ) - - def _debug_print(self, msg: str): - if not DEBUG_GDN_LAYER: - return - print(f"[GDN {self.prefix}] {msg}") - def forward( self, hidden_states: torch.Tensor, @@ -1313,30 +1343,17 @@ def forward( """Forward pass with custom op for core attention.""" num_tokens = hidden_states.size(0) - # self._cached_hidden_states = hidden_states # Cache for debug in _forward_core - # self._debug_print(f"===== FORWARD START (num_tokens={num_tokens}) =====") - # self._debug_tensor("hidden_states", hidden_states) - # Part 1: Input Projection projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_ba, _ = self.in_proj_ba(hidden_states) - # self._debug_tensor("projected_states_qkvz", projected_states_qkvz) - # self._debug_tensor("projected_states_ba", projected_states_ba) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) - # self._debug_tensor("query (after fix_ordering)", query) - # self._debug_tensor("key (after fix_ordering)", key) - # self._debug_tensor("value (after fix_ordering)", value) - # self._debug_tensor("z (after fix_ordering)", z) - # self._debug_tensor("b (after fix_ordering)", b) - # self._debug_tensor("a (after fix_ordering)", a) # Flatten heads: [tokens, heads, head_dim] -> [tokens, heads * head_dim] query = query.reshape(query.size(0), -1) key = key.reshape(key.size(0), -1) value = value.reshape(value.size(0), -1) mixed_qkv = torch.cat((query, key, value), dim=-1) - # self._debug_tensor("mixed_qkv (flattened)", mixed_qkv) # Part 2: Core Attention (Custom Op) core_attn_out = torch.zeros( @@ -1352,58 +1369,19 @@ def forward( core_attn_out, self.prefix, ) - # self._debug_tensor("core_attn_out (after custom op)", core_attn_out) # Part 3: Output Projection z_shape_og = z.shape core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - # self._debug_tensor("core_attn_out (before norm)", core_attn_out) - # self._debug_tensor("z (before norm)", z) - # Debug last token before norm (reshaped has tokens * heads rows) - if DEBUG_GDN_LAYER and num_tokens > 0: - num_heads = self.num_v_heads // self.tp_size - last_token_start = (num_tokens - 1) * num_heads - last_attn = core_attn_out[last_token_start : last_token_start + 1, :8] - last_z = z[last_token_start : last_token_start + 1, :8] - print( - f"[GDN {self.prefix}] core_attn_out before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_attn.flatten().float().tolist())}]" - ) - print( - f"[GDN {self.prefix}] z before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_z.flatten().float().tolist())}]" - ) - # self._debug_tensor("norm.weight", self.norm.weight) - # self._debug_print(f"norm.norm_before_gate={self.norm.norm_before_gate}, norm.eps={self.norm.eps}") core_attn_out = self.norm(core_attn_out, z) - # self._debug_tensor("core_attn_out (after norm)", core_attn_out) - # Debug last token after norm - if DEBUG_GDN_LAYER and num_tokens > 0: - last_attn_after = core_attn_out[last_token_start : last_token_start + 1, :8] - print( - f"[GDN {self.prefix}] core_attn_out after norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_attn_after.flatten().float().tolist())}]" - ) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") # Align dtype with projection weights (FA kernels may yield float32) target_dtype = self.out_proj.weight.dtype if core_attn_out.dtype != target_dtype: core_attn_out = core_attn_out.to(target_dtype) - # self._debug_tensor("core_attn_out (before out_proj)", core_attn_out) output[:num_tokens], _ = self.out_proj(core_attn_out) - # self._debug_tensor("output (final)", output[:num_tokens]) - # Show last token specifically - if DEBUG_GDN_LAYER: - last_token = output[num_tokens - 1, :8] - vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist()) - print(f"[GDN {self.prefix}] output (last token): last_token_first8=[{vals}]") - # Debug output hidden states during decode (num_tokens == 1) - if DEBUG_GDN_OUTPUT and num_tokens == 1: - flat = output[:num_tokens].flatten() - first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) - print( - f"[vLLM-GDN {self.prefix}] OUTPUT hs: shape={output[:num_tokens].shape}, mean={output[:num_tokens].float().mean().item():.6f}, std={output[:num_tokens].float().std().item():.6f}, first8=[{first8}]" - ) - # self._debug_print("===== FORWARD END =====") def _forward_core( self, @@ -1413,16 +1391,11 @@ def _forward_core( core_attn_out: torch.Tensor, ): """Core attention computation (called by custom op).""" - # self._debug_print("===== _forward_core START =====") - # self._debug_tensor("mixed_qkv (input to core)", mixed_qkv) - # self._debug_tensor("b (input to core)", b) - # self._debug_tensor("a (input to core)", a) forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: - # self._debug_print("attn_metadata is None, returning early") return assert isinstance(attn_metadata, dict) @@ -1436,35 +1409,35 @@ def _forward_core( non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor num_actual_tokens = attn_metadata.num_actual_tokens - # self._debug_print(f"num_actual_tokens={num_actual_tokens}, num_prefills={attn_metadata.num_prefills}, num_decodes={attn_metadata.num_decodes}") - # self._debug_print(f"has_initial_state={has_initial_state}") - # self._debug_print(f"non_spec_query_start_loc={non_spec_query_start_loc}") - self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - # self._debug_tensor("conv_state (from cache)", conv_state) - # self._debug_tensor("ssm_state (from cache)", ssm_state) + # During FULL CUDA graph capture the dummy batch size can exceed + # the number of allocated KV cache blocks (e.g. capture=512 but + # only 282 cache lines). Clamp to avoid out-of-bounds access in + # conv1d and recurrent kernels. + # IMPORTANT: Only clamp during CUDA graph capture. During normal + # inference, num_actual_tokens is the number of *tokens* in the + # batch, which can exceed num_blocks (= number of *state slots*) + # during prefill (many tokens share the same state slot). + # Clamping during prefill would truncate the input and cause the + # conv1d/recurrent kernels to read out-of-bounds via cache_indices + # that still reference the full batch. + num_cache_lines = conv_state.size(0) + if _compile_monitor.cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: + num_actual_tokens = num_cache_lines mixed_qkv = mixed_qkv[:num_actual_tokens] b = b[:num_actual_tokens] a = a[:num_actual_tokens] - # self._debug_tensor("mixed_qkv (truncated)", mixed_qkv) - # self._debug_tensor("b (truncated)", b) - # self._debug_tensor("a (truncated)", a) - # Convolution conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - # self._debug_tensor("conv_weights", conv_weights) - # self._debug_tensor("conv1d.bias", self.conv1d.bias) - # self._debug_print(f"activation={self.activation}") if attn_metadata.num_prefills > 0: - # self._debug_print("Using causal_conv1d_fn (prefill path)") mixed_qkv_T = mixed_qkv.transpose(0, 1) - # self._debug_tensor("mixed_qkv_T (before conv)", mixed_qkv_T) + mixed_qkv = causal_conv1d_fn( mixed_qkv_T, conv_weights, @@ -1477,8 +1450,7 @@ def _forward_core( metadata=attn_metadata, ).transpose(0, 1) else: - # self._debug_print("Using causal_conv1d_update (decode path)") - mixed_qkv = causal_conv1d_update( + causal_conv1d_update( mixed_qkv, conv_state, conv_weights, @@ -1488,49 +1460,21 @@ def _forward_core( validate_data=True, ) - # self._debug_tensor("mixed_qkv (after conv)", mixed_qkv) - query, key, value = self.rearrange_mixed_qkv(mixed_qkv) - # self._debug_tensor("query (after rearrange)", query) - # self._debug_tensor("key (after rearrange)", key) - # self._debug_tensor("value (after rearrange)", value) # Expand K heads to V heads for grouped query attention # (matches Fast-LLM and transformers reference implementations) # Always call repeat_interleave (no-op when value_heads_per_key == 1) to avoid # conditional branches that confuse torch.compile - # self._debug_print(f"Expanding K heads to V heads (value_heads_per_key={self.value_heads_per_key})") query = query.repeat_interleave(self.value_heads_per_key, dim=2) key = key.repeat_interleave(self.value_heads_per_key, dim=2) - # self._debug_tensor("query (after expand)", query) - # self._debug_tensor("key (after expand)", key) - # self._debug_tensor("A_log", self.A_log) - # self._debug_tensor("dt_bias", self.dt_bias) g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - # self._debug_tensor("g (from gating)", g) - # self._debug_tensor("beta (from gating)", beta) # Recurrent attention if attn_metadata.num_prefills > 0: - # self._debug_print("Using chunk_gated_delta_rule (prefill)") initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 - # self._debug_tensor("initial_state", initial_state) - # Debug PREFILL INPUTS before kernel call - if DEBUG_GDN_STATE: - print(f"[vLLM-GDN {self.prefix}] PREFILL INPUTS:") - print( - f" hidden_states: shape={self._cached_hidden_states.shape}, first8={self._cached_hidden_states.flatten()[:8].tolist()}" - ) - print(f" mixed_qkv (input): shape={mixed_qkv.shape}, first8={mixed_qkv.flatten()[:8].tolist()}") - print(f" q: shape={query.shape}, first8={query.flatten()[:8].tolist()}") - print(f" k: shape={key.shape}, first8={key.flatten()[:8].tolist()}") - print(f" v: shape={value.shape}, first8={value.flatten()[:8].tolist()}") - print(f" g: shape={g.shape}, first8={g.flatten()[:8].tolist()}") - print(f" beta: shape={beta.shape}, first8={beta.flatten()[:8].tolist()}") - print(f" initial_state: {initial_state}") - print(f" cu_seqlens: {non_spec_query_start_loc}") core_out, last_state = chunk_gated_delta_rule( q=query, k=key, @@ -1543,27 +1487,10 @@ def _forward_core( head_first=False, use_qk_l2norm_in_kernel=True, ) - # self._debug_tensor("core_out (from chunk_gated_delta_rule)", core_out) - # self._debug_tensor("last_state", last_state) - # # Debug prefill state - get seq_len from query_start_loc - # if non_spec_query_start_loc is not None and len(non_spec_query_start_loc) >= 2: - # prefill_seq_len = int(non_spec_query_start_loc[1] - non_spec_query_start_loc[0]) - # else: - # prefill_seq_len = num_actual_tokens - # self._debug_state_stats("PREFILL out_state", last_state, prefill_seq_len) ssm_state[non_spec_state_indices_tensor] = last_state.to(ssm_state.dtype) else: - # self._debug_print("Using fused_recurrent_gated_delta_rule (decode)") - # # For decode, access the correct slot using state indices - # if non_spec_state_indices_tensor is not None and len(non_spec_state_indices_tensor) > 0: - # slot_idx = int(non_spec_state_indices_tensor[0]) - # actual_state = ssm_state[slot_idx:slot_idx+1] - # # self._debug_state_stats("DECODE in_state", actual_state, num_actual_tokens) - # Debug decode inputs - if DEBUG_GDN_STATE: - print( - f"[vLLM-GDN {self.prefix}] DECODE inputs: q={query.flatten()[:4].tolist()}, k={key.flatten()[:4].tolist()}, v={value.flatten()[:4].tolist()}, g={g.flatten()[:4].tolist()}, beta={beta.flatten()[:4].tolist()}" - ) + # num_actual_tokens already clamped to cache lines above + num_decodes = min(attn_metadata.num_decodes, num_actual_tokens) core_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -1572,18 +1499,11 @@ def _forward_core( beta=beta, initial_state=ssm_state, inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, + cu_seqlens=non_spec_query_start_loc[: num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], use_qk_l2norm_in_kernel=True, ) - # self._debug_tensor("core_out (from fused_recurrent)", core_out) - # if non_spec_state_indices_tensor is not None and len(non_spec_state_indices_tensor) > 0: - # actual_state = ssm_state[slot_idx:slot_idx+1] - # # self._debug_state_stats("DECODE out_state", actual_state, num_actual_tokens) - core_attn_out[:num_actual_tokens] = core_out.squeeze(0)[:num_actual_tokens] - # self._debug_tensor("core_attn_out (final output)", core_attn_out) - # self._debug_print("===== _forward_core END =====") def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Checkpoint uses "convolution", model uses "conv1d" @@ -1605,10 +1525,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self.norm.weight.data.copy_(weight) loaded.add("norm.weight") elif name == "A_log": - self.A_log.data.copy_(weight) + self.A_log.weight_loader(self.A_log, weight) loaded.add("A_log") elif name == "dt_bias": - self.dt_bias.data.copy_(weight) + self.dt_bias.weight_loader(self.dt_bias, weight) loaded.add("dt_bias") return loaded @@ -1837,6 +1757,7 @@ def forward( dtype=hidden_states.dtype, device=hidden_states.device, ) + torch.ops.vllm.kda_attention( q, k, @@ -1874,17 +1795,25 @@ def _forward( num_actual_tokens = attn_metadata.num_actual_tokens constant_caches = self.kv_cache[forward_context.virtual_engine] + (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches + conv_state_q = conv_state_q.transpose(-1, -2) + conv_state_k = conv_state_k.transpose(-1, -2) + conv_state_v = conv_state_v.transpose(-1, -2) + + # During FULL CUDA graph capture the dummy batch size can exceed + # the number of allocated KV cache blocks. Clamp to avoid + # out-of-bounds access. + # IMPORTANT: Only clamp during CUDA graph capture — see GDN comment. + num_cache_lines = conv_state_q.size(0) + if _compile_monitor.cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: + num_actual_tokens = num_cache_lines + q_proj_states = q_proj_states[:num_actual_tokens] k_proj_states = k_proj_states[:num_actual_tokens] v_proj_states = v_proj_states[:num_actual_tokens] g1 = g1[:num_actual_tokens] beta = beta[:num_actual_tokens] - (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches - conv_state_q = conv_state_q.transpose(-1, -2) - conv_state_k = conv_state_k.transpose(-1, -2) - conv_state_v = conv_state_v.transpose(-1, -2) - q_conv_weights = self.q_conv1d.weight.view(self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)) k_conv_weights = self.k_conv1d.weight.view(self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)) v_conv_weights = self.v_conv1d.weight.view(self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)) @@ -1975,6 +1904,7 @@ def _forward( ) recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state else: + num_decodes = min(attn_metadata.num_decodes, num_actual_tokens) core_attn_out_non_spec, _ = fused_recurrent_kda( q=q, k=k, @@ -1983,8 +1913,8 @@ def _forward( beta=beta, initial_state=recurrent_state, use_qk_l2norm_in_kernel=True, - cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, + cu_seqlens=non_spec_query_start_loc[: num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], ) core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[0, :num_actual_tokens] @@ -2033,10 +1963,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self.g_b_proj.weight_loader(self.g_b_proj.weight, weight) loaded.add("g_b_proj.weight") elif name == "A_log": - self.A_log.data.copy_(weight) + self.A_log.weight_loader(self.A_log, weight) loaded.add("A_log") elif name == "dt_bias": - self.dt_bias.data.copy_(weight) + self.dt_bias.weight_loader(self.dt_bias, weight) loaded.add("dt_bias") return loaded @@ -2227,51 +2157,22 @@ def __init__( self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) - def _debug_tensor(self, name: str, t: torch.Tensor, show_last=False): - if not DEBUG_DECODER_LAYER or t is None: - return - if show_last: - # Show last token - last = t[-1, :8] if t.dim() == 2 else t[0, -1, :8] - vals = ", ".join(f"{v:.6f}" for v in last.float().tolist()) - print(f"[vLLM Layer] {name}: shape={t.shape}, last_token_first8=[{vals}]") - else: - flat = t.flatten()[:8] - vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) - print(f"[vLLM Layer] {name}: shape={t.shape}, first8=[{vals}]") - def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor | None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: - # self._debug_tensor("input hidden_states", hidden_states) - # self._debug_tensor("input residual", residual) - if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - # self._debug_tensor("after input_layernorm", hidden_states) - # self._debug_tensor("residual after input_layernorm", residual) - output = torch.empty_like(hidden_states) self.mixer(hidden_states, output) - # self._debug_tensor("mixer output", output) - hidden_states, residual = self.post_attention_layernorm(output, residual) - # self._debug_tensor("after post_attention_layernorm", hidden_states) - # self._debug_tensor("residual after post_attention_layernorm", residual) - hidden_states = self.mlp(hidden_states) - # self._debug_tensor("after mlp", hidden_states) - # Also show last token for final layer comparison - # self._debug_tensor("after mlp (last token)", hidden_states, show_last=True) - # self._debug_tensor("residual (last token)", residual, show_last=True) - return hidden_states, residual @@ -2758,33 +2659,7 @@ def forward( if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) - # Debug final norm - if DEBUG_FINAL_NORM: - # Show LAST token (to match TF) - last_hs = hidden_states[-1, :8] - last_res = residual[-1, :8] if residual is not None else None - hs_vals = ", ".join(f"{v:.6f}" for v in last_hs.float().tolist()) - res_vals = ", ".join(f"{v:.6f}" for v in last_res.float().tolist()) if last_res is not None else "None" - print( - f"[vLLM Final] hidden_states (before norm): shape={hidden_states.shape}, last_token_first8=[{hs_vals}]" - ) - print( - f"[vLLM Final] residual (before norm): shape={residual.shape if residual is not None else None}, last_token_first8=[{res_vals}]" - ) - print( - f"[vLLM Final] norm.weight: first8=[{', '.join(f'{v:.6f}' for v in self.norm.weight.flatten()[:8].float().tolist())}]" - ) - print(f"[vLLM Final] norm.variance_epsilon={self.norm.variance_epsilon}") - hidden_states, _ = self.norm(hidden_states, residual) - - if DEBUG_FINAL_NORM: - last_out = hidden_states[-1, :8] - out_vals = ", ".join(f"{v:.6f}" for v in last_out.float().tolist()) - print( - f"[vLLM Final] hidden_states (after norm): shape={hidden_states.shape}, last_token_first8=[{out_vals}]" - ) - return hidden_states @@ -2865,28 +2740,7 @@ def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: - # Debug LM head input - if DEBUG_LM_HEAD: - flat = hidden_states.flatten()[:8] - vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) - print(f"[vLLM LM Head] input hidden_states: shape={hidden_states.shape}, first8=[{vals}]") - if self.lm_head is not None: - lm_weight = self.lm_head.weight - print( - f"[vLLM LM Head] lm_head.weight: shape={lm_weight.shape}, first8=[{', '.join(f'{v:.6f}' for v in lm_weight.flatten()[:8].float().tolist())}]" - ) - logits = self.logits_processor(self.lm_head, hidden_states) - - if DEBUG_LM_HEAD and logits is not None: - # Get last token logits - last_logits = logits[-1] if logits.dim() == 2 else logits[0, -1] - top_vals, top_idx = last_logits.topk(5) - print(f"[vLLM LM Head] logits shape={logits.shape}") - print( - f"[vLLM LM Head] last token top-5 logits: {[(idx.item(), val.item()) for idx, val in zip(top_idx, top_vals)]}" - ) - return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: