
Vllm modelling cleanup #483

Open
oleksost wants to merge 13 commits into feature/vllm-apriel2-models from
oo/feature/vllm-apriel2-model-modeling

Conversation

@oleksost
Contributor

@oleksost oleksost commented Apr 9, 2026

Context:

  • the `oo/feature/vllm-aprie2-model-permixergraph/` branch was used for throughput benchmarking and for obtaining some of the results for super Apriel (production tasks, i.e. non-dev)
  • this is a cleanup of the modeling file, which also includes some of the throughput-related fixes (mostly relevant for the cost model) described below

Tests (vllm 0.14rc1):

  • serving as a full supernet using the layer placement code works well
  • serving in "surgery" mode, i.e. only a predefined layout with plain `vllm serve`, works well
  • currently making sure that we can reproduce the throughput numbers

Summary

Three categories of changes: bug fixes, a KV cache throughput fix, and a general cleanup of the modeling file.

Bug fixes

  • GDN/KDA OOB crash during CUDA graph capture (9a6562d): During FULL CUDA graph capture vLLM creates dummy runs with batch sizes up to
    max_cudagraph_capture_size. When gpu_memory_utilization is low, the number of KV cache blocks can be smaller than the capture batch size, causing
    causal_conv1d_update and recurrent kernel assertions to fail. Fix: clamp num_actual_tokens to conv_state.size(0) only during CUDA graph capture
    (cudagraph_capturing_enabled). Also fixes cu_seqlens and ssm_state_indices slicing for decode paths in both GDN and KDA, and removes a spurious
    return-value assignment for the in-place causal_conv1d_update.
  • cudagraph_capturing_enabled import-by-value bug causing OOB crash on re-prefill (60541b0): from vllm.compilation.monitor import
    cudagraph_capturing_enabled captures the default value True at import time. When vLLM later calls set_cudagraph_capturing_enabled(False), the
    module-level name stays True, causing num_actual_tokens to be wrongly clamped during normal prefill — truncating inputs while query_start_loc still
    referenced the full batch, crashing causal_conv1d_fn. Fix: use vllm.compilation.monitor.cudagraph_capturing_enabled (module attribute access)
    everywhere.
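The import-by-value pitfall above is plain Python `from … import` semantics, not anything vLLM-specific. A minimal standalone sketch (using a stand-in module object, not the real `vllm.compilation.monitor`):

```python
import types

# Stand-in for vllm.compilation.monitor (illustrative, not the real module).
monitor = types.ModuleType("monitor")
monitor.cudagraph_capturing_enabled = True

# Buggy pattern: `from ... import cudagraph_capturing_enabled` binds the
# *current value* (True) to a new local name at import time.
captured_flag = monitor.cudagraph_capturing_enabled

# Later, the flag is flipped on the module (set_cudagraph_capturing_enabled(False)).
monitor.cudagraph_capturing_enabled = False

print(captured_flag)                        # True  -- stale copy, triggers the wrong clamp
print(monitor.cudagraph_capturing_enabled)  # False -- attribute access sees the update
```

Reading the attribute through the module object on every check is what makes the fix work: the name lookup happens at call time, not at import time.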

KV cache grouping and related kernel fixes

  • Throughput regression fix for hybrid models with singleton mixer types (7a948a4): vLLM's _get_kv_cache_groups_uniform_page_size uses
    min(num_layers_per_type) as group size, creating O(num_layers) groups for models like a12_g1_k11 (where min=1). Each group triggers a metadata rebuild
    per forward step, causing >2× throughput regression. Monkey-patches the grouping function to use max(num_layers_per_type) when a
    singleton/near-singleton type is detected, reducing groups to O(num_types).
  • apriel2_gdn_attention_core always registered as PIECEWISE splitting op: Without the graph break, the decode path gets baked into a compiled prefill
    piece — prefill batches then replay the wrong path causing illegal memory access.
  • Apriel2Attention.get_kv_cache_spec: Removes spurious else: branch (cosmetic, same behavior).
  • fused_gdn_gating_kernel: Removes tl.constexpr from total_elements (was causing issues with variable batch sizes).
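The group-size change in the monkey-patch can be sketched as follows (the function name and `singleton_threshold` are illustrative, not vLLM's actual internals):

```python
from collections import Counter

def patched_group_size(layer_types, singleton_threshold=2):
    """Pick the KV cache group size for a hybrid model.

    vLLM's default takes min(count) across mixer types, which explodes into
    O(num_layers) groups when one type is a singleton; switching to
    max(count) in that case keeps the group count at O(num_types).
    """
    per_type = Counter(layer_types).values()
    if min(per_type) <= singleton_threshold:  # singleton/near-singleton type present
        return max(per_type)
    return min(per_type)

# a12_g1_k11: 12 attention layers, 1 GDN layer, 11 KDA layers.
layers = ["attn"] * 12 + ["gdn"] * 1 + ["kda"] * 11
print(patched_group_size(layers))  # 12, instead of the default min(...) == 1
```

With group size 12 instead of 1, the per-step metadata rebuild happens once per type rather than once per layer, which is where the >2× throughput regression came from.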

Cleanup

  • Removed ~800 lines of WIP code: per-mixer CUDA graph caching (MixerGraphCache), CPU offload for inactive mixers, APRIEL2_FULL_CUDA_GRAPHS /
    APRIEL2_MIXER_CUDA_GRAPHS / APRIEL2_OFFLOAD_INACTIVE_MIXERS env vars, and associated documentation.
  • Removed all debug artifacts: DEBUG_* flags, _debug_print/_debug_tensor/_debug_state_stats methods, commented-out debug calls, and CUDA sync debug
    points.

oleksost and others added 8 commits February 23, 2026 15:45
  Previously Apriel2 set is_hybrid=False and bypassed vLLM's config verification entirely because the default dispatch couldn't handle heterogeneous or pure-mamba
  compositions. These changes introduce a proper custom config handler that correctly routes to the right vLLM verification path based on the actual layer composition, and
  provides the get_mamba_state_shape/dtype methods needed for hybrid page size alignment.
During FULL CUDA graph capture, vLLM creates dummy runs with batch sizes
up to max_cudagraph_capture_size (e.g. 512). When gpu_memory_utilization
is low, the number of allocated KV cache blocks can be smaller than this
capture batch size (e.g. 282 blocks for 0.05 utilization). This caused
an assertion failure in causal_conv1d_update:
  assert num_cache_lines >= batch

The root cause: GDN and KDA decode paths used num_actual_tokens (which
equals the capture batch size during FULL graph capture) to slice
conv_state_indices and cu_seqlens, but the conv/recurrent state tensors
only have num_cache_blocks rows.

Fix: clamp num_actual_tokens to conv_state.size(0) at the top of both
GDN._forward_core() and KDA._forward(). This ensures all downstream
kernel calls (causal_conv1d_update, fused_recurrent_gated_delta_rule,
fused_recurrent_kda) receive batch sizes that fit within the allocated
cache. During normal inference the scheduler guarantees
num_actual_tokens <= cache blocks, so the clamp is a no-op.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
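The clamp described in this commit reduces to a one-line guard; a schematic version (plain ints stand in for `num_actual_tokens` and `conv_state.size(0)`, and the function name is illustrative):

```python
def clamp_num_actual_tokens(num_actual_tokens, num_cache_lines, capturing):
    """Sketch of the capture-time clamp.

    During FULL CUDA graph capture the dummy batch can exceed the allocated
    conv/recurrent state rows, so downstream kernels (causal_conv1d_update,
    fused_recurrent_*) would index out of bounds; clamp only in that case.
    """
    if capturing:
        return min(num_actual_tokens, num_cache_lines)
    # During normal inference the scheduler guarantees the batch fits.
    return num_actual_tokens

print(clamp_num_actual_tokens(512, 282, capturing=True))   # 282: dummy capture run
print(clamp_num_actual_tokens(128, 282, capturing=False))  # 128: no-op at inference
```

The numbers mirror the example in the commit message: a 512-token capture batch against 282 allocated cache blocks at 0.05 GPU memory utilization.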
…hybrid models

Three categories of changes for debugging and fixing throughput regression
in hybrid Apriel2 models with singleton mixer types (e.g. a12_g1_k11):

1. KV cache grouping monkey-patch: use max(num_layers_per_type) as group_size
   when min <= 2, reducing O(num_layers) groups to O(num_types) groups.

2. Always register apriel2_gdn_attention_core as PIECEWISE splitting op,
   fix fused_gdn_gating_kernel total_elements constexpr, and only clamp
   num_actual_tokens during CUDA graph capture (not during prefill).

3. Unconditional CUDA sync points (SYNC-1 through SYNC-6) with print(flush=True)
   to pinpoint illegal memory access during large re-prefill after preemption.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… on re-prefill

`from vllm.compilation.monitor import cudagraph_capturing_enabled` copies
the default value True at import time. When vLLM later calls
set_cudagraph_capturing_enabled(False), the module global changes but the
imported name stays True. This caused num_actual_tokens to be wrongly
clamped to num_cache_lines during inference (not just CUDA graph capture),
truncating mixed_qkv while query_start_loc still referenced the full
16384-token batch — making causal_conv1d_fn read OOB and crash.

Fix: use _compile_monitor.cudagraph_capturing_enabled (module attribute
access) for all runtime checks so the current value is always read.
Debug sync points (SYNC-1..6) kept behind APRIEL2_DEBUG_SYNC=1 flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@oleksost oleksost changed the title Oo/feature/vllm apriel2 model modeling Vllm modelling cleanup Apr 9, 2026
oleksost and others added 2 commits April 9, 2026 21:09
…erve bug fixes

Remove code added by the 5 intermediate WIP commits (per-mixer graph caching,
CPU offload for inactive mixers, FULL CUDA graph investigation, and associated
documentation):
- APRIEL2_FULL_CUDA_GRAPHS / APRIEL2_MIXER_CUDA_GRAPHS / APRIEL2_OFFLOAD_INACTIVE_MIXERS env vars
- MixerGraphCache class and _capture_all_mixers_for_num_tokens
- _move_module_device and CPU offload logic in StochasticMixer
- cudagraph_capturing_enabled check in stochastic_mixer_dispatch
- Extensive CUDA graph mode benchmark documentation (~300 lines)

Preserve all changes from the 3 bug-fix commits (9a6562d, 7a948a4, 60541b0):
- _patch_kv_cache_grouping() for heterogeneous hybrid model KV cache grouping
- import vllm.compilation.monitor as _compile_monitor (fix import-by-value bug)
- GDN/KDA clamp (cudagraph_capturing_enabled-conditional) to prevent OOB during capture
- causal_conv1d_update in-place fix (don't assign return value)
- num_decodes clamped to num_actual_tokens + ssm_state_indices[:num_actual_tokens]
- apriel2_gdn_attention_core always registered as PIECEWISE splitting op
- Attention.get_kv_cache_spec: remove spurious else branch
- fused_gdn_gating_kernel total_elements: remove tl.constexpr
- DEBUG_SYNC env var + SYNC-1..6 cuda sync debug points (gated behind APRIEL2_DEBUG_SYNC=1)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove all debug-only code that accumulated during development:
- DEBUG_* flag definitions (DEBUG_GDN_LAYER, DEBUG_GDN_STATE, DEBUG_GDN_OUTPUT,
  DEBUG_KDA_LAYER, DEBUG_DECODER_LAYER, DEBUG_FINAL_NORM, DEBUG_LM_HEAD, DEBUG_SYNC)
- _debug_print, _debug_tensor, _debug_state_stats helper methods
- All if DEBUG_*: blocks with print statements
- All commented-out # self._debug_* / # self._cached_* lines
- SYNC-1..6 CUDA sync debug points
- import os (only used by DEBUG_SYNC)

Functional code and bug fixes from previous commits are unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@oleksost oleksost force-pushed the oo/feature/vllm-apriel2-model-modeling branch from 9f4686b to 97b84d5 on April 9, 2026 21:10
oleksost and others added 3 commits April 9, 2026 21:10
These docs covered per-mixer CUDA graph caching and layer placement change
procedures — both tied to the WIP infrastructure that was removed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In Apriel2GatedDeltaNet and Apriel2KDAMixer, A_log and dt_bias
parameters were loaded with .data.copy_() which bypasses TP sharding.
This caused a size mismatch (e.g. 16 vs 32) when running with
tensor_parallel_size > 1. Switch to weight_loader() which correctly
handles the sharding.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
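The sharding mismatch can be illustrated without vLLM: the full checkpoint tensor must be narrowed to this rank's shard before copying. In this sketch plain lists stand in for tensors, and `sharded_loader` is a hypothetical stand-in for vLLM's `weight_loader`:

```python
tp_size, tp_rank = 2, 0

full_weight = list(range(32))            # checkpoint A_log: 32 heads
local_param = [0.0] * (32 // tp_size)    # this rank holds only 16 rows

def sharded_loader(param, loaded_weight, size, rank):
    """Copy only this rank's shard of the full checkpoint tensor."""
    n = len(loaded_weight) // size
    param[:] = loaded_weight[rank * n:(rank + 1) * n]

# A raw .data.copy_() equivalent would fail here: 32 values into 16 rows,
# which is the "16 vs 32" size mismatch seen with tensor_parallel_size > 1.
sharded_loader(local_param, full_weight, tp_size, tp_rank)
print(len(local_param), local_param[:4])  # 16 [0, 1, 2, 3]
```

The point of the fix is simply that `weight_loader()` performs this narrowing, while `.data.copy_()` copies the unsharded tensor verbatim.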
Adds apriel2-vllm-plugin/pyproject.toml that installs only the vLLM
plugin entry point and its minimal dependencies (torch, transformers,
einops). This avoids pulling in the full Fast-LLM training framework.

Usage: pip install -e ./apriel2-vllm-plugin/

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>