vLLM modelling cleanup #483
oleksost wants to merge 13 commits into feature/vllm-apriel2-models
Conversation
Previously Apriel2 set is_hybrid=False and bypassed vLLM's config verification entirely because the default dispatch couldn't handle heterogeneous or pure-mamba compositions. These changes introduce a proper custom config handler that correctly routes to the right vLLM verification path based on the actual layer composition, and provides the get_mamba_state_shape/dtype methods needed for hybrid page size alignment.
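The routing described above can be sketched as a small dispatcher. This is a minimal illustration, not the PR's actual handler: the function name, the type labels, and the returned path names are all hypothetical.

```python
def pick_verification_path(layer_types: list[str]) -> str:
    """Illustrative sketch: choose a config-verification path from the
    model's actual layer composition (labels here are made up).

    - all-attention  -> the standard vLLM verification path
    - all recurrent (mamba-style GDN/KDA) -> the mamba path
    - any mixture    -> the hybrid path
    """
    kinds = set(layer_types)
    if kinds == {"attention"}:
        return "standard"
    if kinds <= {"mamba", "gdn", "kda"}:
        return "mamba"
    return "hybrid"

# A heterogeneous composition like a12_g1_k11 mixes attention with GDN/KDA.
print(pick_verification_path(["attention"] * 12 + ["gdn"] + ["kda"] * 11))
```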
During FULL CUDA graph capture, vLLM creates dummy runs with batch sizes up to `max_cudagraph_capture_size` (e.g. 512). When `gpu_memory_utilization` is low, the number of allocated KV cache blocks can be smaller than this capture batch size (e.g. 282 blocks at 0.05 utilization). This caused an assertion failure in `causal_conv1d_update`: `assert num_cache_lines >= batch`.

Root cause: the GDN and KDA decode paths used `num_actual_tokens` (which equals the capture batch size during FULL graph capture) to slice `conv_state_indices` and `cu_seqlens`, but the conv/recurrent state tensors only have `num_cache_blocks` rows.

Fix: clamp `num_actual_tokens` to `conv_state.size(0)` at the top of both `GDN._forward_core()` and `KDA._forward()`. This ensures all downstream kernel calls (`causal_conv1d_update`, `fused_recurrent_gated_delta_rule`, `fused_recurrent_kda`) receive batch sizes that fit within the allocated cache. During normal inference the scheduler guarantees `num_actual_tokens <= cache blocks`, so the clamp is a no-op.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
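The clamp-and-slice logic can be sketched in isolation. This is a simplified stand-in for what the commit does inside the decode paths; the helper name is hypothetical and plain lists stand in for the real tensors.

```python
def clamp_decode_batch(num_actual_tokens, num_cache_lines,
                       conv_state_indices, cu_seqlens):
    """Clamp the decode batch to the number of allocated cache lines.

    During FULL CUDA graph capture, dummy runs can exceed the allocated
    cache (the state tensors have one row per cache block), so downstream
    kernels asserting num_cache_lines >= batch would fail. During normal
    inference the scheduler guarantees num_actual_tokens <= num_cache_lines,
    so the clamp is a no-op.
    """
    batch = min(num_actual_tokens, num_cache_lines)
    # cu_seqlens has one more entry than the batch (prefix-sum boundaries).
    return batch, conv_state_indices[:batch], cu_seqlens[: batch + 1]

# Capture-time dummy run: 512-token batch, but only 282 allocated blocks.
batch, idx, cu = clamp_decode_batch(512, 282, list(range(512)), list(range(513)))
print(batch, len(idx), len(cu))  # 282 282 283
```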
…hybrid models

Three categories of changes for debugging and fixing a throughput regression in hybrid Apriel2 models with singleton mixer types (e.g. a12_g1_k11):

1. KV cache grouping monkey-patch: use `max(num_layers_per_type)` as `group_size` when min <= 2, reducing O(num_layers) groups to O(num_types) groups.
2. Always register `apriel2_gdn_attention_core` as a PIECEWISE splitting op, fix the `fused_gdn_gating_kernel` `total_elements` constexpr, and only clamp `num_actual_tokens` during CUDA graph capture (not during prefill).
3. Unconditional CUDA sync points (SYNC-1 through SYNC-6) with `print(flush=True)` to pinpoint an illegal memory access during large re-prefill after preemption.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
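The group-size choice in item 1 can be sketched as a pure function. The helper name and the dict-based interface are illustrative, not the patched vLLM internals.

```python
def choose_group_size(num_layers_per_type: dict[str, int]) -> int:
    """Illustrative sketch of the grouping heuristic.

    The default behaviour effectively groups by min(num_layers_per_type),
    which for singleton mixer types (min == 1) yields O(num_layers) groups
    and a per-step metadata rebuild for each. When a singleton or
    near-singleton type is present (min <= 2), use max(...) instead,
    collapsing the count to O(num_types) groups.
    """
    counts = num_layers_per_type.values()
    return max(counts) if min(counts) <= 2 else min(counts)

# a12_g1_k11: 12 attention, 1 GDN, 11 KDA layers -> min=1, so use max=12.
print(choose_group_size({"attention": 12, "gdn": 1, "kda": 11}))  # 12
```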
… on re-prefill

`from vllm.compilation.monitor import cudagraph_capturing_enabled` copies the default value `True` at import time. When vLLM later calls `set_cudagraph_capturing_enabled(False)`, the module global changes but the imported name stays `True`. This caused `num_actual_tokens` to be wrongly clamped to `num_cache_lines` during inference (not just CUDA graph capture), truncating `mixed_qkv` while `query_start_loc` still referenced the full 16384-token batch, making `causal_conv1d_fn` read OOB and crash.

Fix: use `_compile_monitor.cudagraph_capturing_enabled` (module attribute access) for all runtime checks so the current value is always read. Debug sync points (SYNC-1..6) are kept behind the `APRIEL2_DEBUG_SYNC=1` flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
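The import-by-value pitfall is plain Python semantics and can be reproduced without vLLM. Below, a throwaway module built with `types.ModuleType` stands in for `vllm.compilation.monitor`; the module name is made up for the demo.

```python
import sys
import types

# Stand-in for vllm.compilation.monitor (any real module behaves the same).
mon = types.ModuleType("monitor_demo")
mon.cudagraph_capturing_enabled = True
sys.modules["monitor_demo"] = mon

# Buggy pattern: from-import binds the *current value* to a new local name.
from monitor_demo import cudagraph_capturing_enabled

# Later the module global is flipped (as set_cudagraph_capturing_enabled does).
mon.cudagraph_capturing_enabled = False

print(cudagraph_capturing_enabled)      # True  -- stale copy from import time
print(mon.cudagraph_capturing_enabled)  # False -- attribute access sees the update
```

This is why the fix imports the module (`import ... as _compile_monitor`) and reads the attribute at call time instead of importing the name.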
…erve bug fixes

Remove code added by the 5 intermediate WIP commits (per-mixer graph caching, CPU offload for inactive mixers, FULL CUDA graph investigation, and associated documentation):

- APRIEL2_FULL_CUDA_GRAPHS / APRIEL2_MIXER_CUDA_GRAPHS / APRIEL2_OFFLOAD_INACTIVE_MIXERS env vars
- MixerGraphCache class and _capture_all_mixers_for_num_tokens
- _move_module_device and CPU offload logic in StochasticMixer
- cudagraph_capturing_enabled check in stochastic_mixer_dispatch
- Extensive CUDA graph mode benchmark documentation (~300 lines)

Preserve all changes from the 3 bug-fix commits (9a6562d, 7a948a4, 60541b0):

- _patch_kv_cache_grouping() for heterogeneous hybrid model KV cache grouping
- import vllm.compilation.monitor as _compile_monitor (fixes the import-by-value bug)
- GDN/KDA clamp (cudagraph_capturing_enabled-conditional) to prevent OOB during capture
- causal_conv1d_update in-place fix (don't assign the return value)
- num_decodes clamped to num_actual_tokens + ssm_state_indices[:num_actual_tokens]
- apriel2_gdn_attention_core always registered as a PIECEWISE splitting op
- Attention.get_kv_cache_spec: remove spurious else branch
- fused_gdn_gating_kernel total_elements: remove tl.constexpr
- DEBUG_SYNC env var + SYNC-1..6 CUDA sync debug points (gated behind APRIEL2_DEBUG_SYNC=1)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove all debug-only code that accumulated during development:

- DEBUG_* flag definitions (DEBUG_GDN_LAYER, DEBUG_GDN_STATE, DEBUG_GDN_OUTPUT, DEBUG_KDA_LAYER, DEBUG_DECODER_LAYER, DEBUG_FINAL_NORM, DEBUG_LM_HEAD, DEBUG_SYNC)
- _debug_print, _debug_tensor, _debug_state_stats helper methods
- All `if DEBUG_*:` blocks with print statements
- All commented-out `# self._debug_*` / `# self._cached_*` lines
- SYNC-1..6 CUDA sync debug points
- `import os` (only used by DEBUG_SYNC)

Functional code and bug fixes from previous commits are unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
force-pushed from 9f4686b to 97b84d5
These docs covered per-mixer CUDA graph caching and layer placement change procedures — both tied to the WIP infrastructure that was removed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In Apriel2GatedDeltaNet and Apriel2KDAMixer, the A_log and dt_bias parameters were loaded with `.data.copy_()`, which bypasses TP sharding. This caused a size mismatch (e.g. 16 vs 32) when running with tensor_parallel_size > 1. Switch to `weight_loader()`, which correctly handles the sharding. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
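The failure mode can be reproduced without vLLM or torch. Plain lists stand in for the parameter and checkpoint tensors, and both function names are illustrative; the real `weight_loader()` is vLLM's shard-aware loader, not this sketch.

```python
def naive_load(param, loaded_weight):
    """The buggy pattern: a direct copy assumes the checkpoint tensor
    already matches the local (possibly sharded) parameter size."""
    if len(param) != len(loaded_weight):
        raise ValueError(f"size mismatch: {len(param)} vs {len(loaded_weight)}")
    param[:] = loaded_weight

def sharded_weight_loader(param, loaded_weight, tp_rank, tp_size):
    """Sketch of a TP-aware loader: slice the full checkpoint tensor
    down to this rank's shard before copying."""
    shard = len(loaded_weight) // tp_size
    param[:] = loaded_weight[tp_rank * shard:(tp_rank + 1) * shard]

full_A_log = list(range(32))  # full checkpoint tensor
local = [0] * 16              # this rank's shard (tensor_parallel_size=2)

try:
    naive_load(local, full_A_log)  # reproduces the 16-vs-32 mismatch
except ValueError as e:
    print(e)                       # size mismatch: 16 vs 32

sharded_weight_loader(local, full_A_log, tp_rank=1, tp_size=2)
print(local[0])                    # 16
```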
Adds apriel2-vllm-plugin/pyproject.toml that installs only the vLLM plugin entry point and its minimal dependencies (torch, transformers, einops). This avoids pulling in the full Fast-LLM training framework. Usage: pip install -e ./apriel2-vllm-plugin/ Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
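A plugin-only package of this shape might look like the fragment below. This is a hypothetical sketch, not the PR's file: the version, module path, and entry-point name are invented, though vLLM does discover plugins through the `vllm.general_plugins` entry-point group.

```toml
# Hypothetical sketch of apriel2-vllm-plugin/pyproject.toml
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "apriel2-vllm-plugin"
version = "0.1.0"
dependencies = ["torch", "transformers", "einops"]

# vLLM discovers plugins via the "vllm.general_plugins" entry-point group.
[project.entry-points."vllm.general_plugins"]
apriel2 = "apriel2_vllm_plugin:register"
```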
Context:
Tests (vllm 0.14rc1): `vllm serve` works well.

Summary
Three categories of bug fixes, a KV cache throughput fix, and a general cleanup of the modeling file.
Bug fixes
- During FULL CUDA graph capture, vLLM creates dummy runs with batch sizes up to `max_cudagraph_capture_size`. When `gpu_memory_utilization` is low, the number of KV cache blocks can be smaller than the capture batch size, causing `causal_conv1d_update` and recurrent kernel assertions to fail. Fix: clamp `num_actual_tokens` to `conv_state.size(0)` only during CUDA graph capture (`cudagraph_capturing_enabled`). Also fixes `cu_seqlens` and `ssm_state_indices` slicing for decode paths in both GDN and KDA, and removes a spurious return-value assignment for the in-place `causal_conv1d_update`.
- `from vllm.compilation.monitor import cudagraph_capturing_enabled` captures the default value `True` at import time. When vLLM later calls `set_cudagraph_capturing_enabled(False)`, the module-level name stays `True`, causing `num_actual_tokens` to be wrongly clamped during normal prefill, truncating inputs while `query_start_loc` still referenced the full batch and crashing `causal_conv1d_fn`. Fix: use `vllm.compilation.monitor.cudagraph_capturing_enabled` (module attribute access) everywhere.
KV cache grouping fix
- vLLM's default grouping uses `min(num_layers_per_type)` as the group size, creating O(num_layers) groups for models like a12_g1_k11 (where min=1). Each group triggers a metadata rebuild per forward step, causing a >2× throughput regression. The fix monkey-patches the grouping function to use `max(num_layers_per_type)` when a singleton/near-singleton type is detected, reducing groups to O(num_types).
- `apriel2_gdn_attention_core` is always registered as a PIECEWISE splitting op; otherwise large re-prefill batches replay the wrong path, causing an illegal memory access.
Cleanup
- Removed the WIP infrastructure: APRIEL2_FULL_CUDA_GRAPHS / APRIEL2_MIXER_CUDA_GRAPHS / APRIEL2_OFFLOAD_INACTIVE_MIXERS env vars, and associated documentation.
- Removed the DEBUG_* flags, debug helper methods, and SYNC-1..6 CUDA sync points.