From 3bb49318cae7df529459af0c363700ab8da95929 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 23 Feb 2026 15:45:55 +0000 Subject: [PATCH 01/13] per mixer graph --- .../vllm/ai_docs/layer_placement_change.md | 80 +++++++ .../vllm/ai_docs/per_mixer_cuda_graphs.md | 105 +++++++++ .../apriel2/vllm/modeling_apriel2.py | 206 +++++++++++++++++- 3 files changed, 385 insertions(+), 6 deletions(-) create mode 100644 fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md create mode 100644 fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md diff --git a/fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md b/fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md new file mode 100644 index 000000000..ff4ae0335 --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md @@ -0,0 +1,80 @@ +# Layer Placement Change in Apriel2 vLLM + +## Architecture Overview + +Apriel2 is a **hybrid model** with heterogeneous decoder layers. Some layers are "stochastic" — they contain **multiple mixer sub-modules** (e.g., attention, GDN, KDA) but only run **one at a time**. The placement system lets you dynamically switch which mixer is active per layer at runtime. + +## Key Components + +### 1. `Apriel2StochasticMixer` (line ~2344) + +- Contains an `nn.ModuleDict` called `self.mixers` with all sub-mixer instances (e.g., attention, GDN, KDA) +- Tracks `self.active_mixer_name` — which mixer is currently active +- All sub-mixers have their weights loaded, but only one runs during forward pass +- Each sub-mixer gets a **virtual layer index** (`layer_idx + (mixer_index+1) * num_layers`) so they each get separate KV cache allocations without collisions + +### 2. `Apriel2StochasticDecoderLayer` (line ~2513) + +- Wraps `Apriel2StochasticMixer` + MLP + layer norms +- Exposes `set_active_mixer(name)` / `get_active_mixer()` which delegate to the mixer + +### 3. Dynamic dispatch via custom op (line ~870) + +- `stochastic_mixer_dispatch` is registered as a `vllm::stochastic_mixer_dispatch` custom op +- This op is added to vLLM's `_attention_ops` splitting ops list, causing **graph breaks** in torch.compile +- At runtime, it looks up the `Apriel2StochasticMixer` from `forward_context.no_compile_layers[layer_name]`, reads `active_mixer_name`, and forwards to that mixer +- The fake impl just copies input→output to satisfy the compiler's data dependency analysis + +## The Placement Change Call Chain + +From the debug script (`debug_offline.py`): + +```python +llm.collective_rpc("set_layer_placements", args=(placement,)) +``` + +1. **Worker monkey-patch** (line ~2962): `_patch_worker_for_placement_switching()` runs at import time and adds `set_layer_placements`/`get_layer_placements`/`get_mixer_names` methods to `vllm.v1.worker.gpu_worker.Worker` + +2. **`Worker._set_layer_placements`** (line ~3003): + - First calls `_clear_kv_cache(self)` — zeroes out **all** KV cache tensors to prevent stale data from a different mixer type causing NaN errors + - Then calls `self.get_model().set_layer_placements(placement)` + +3. **`Apriel2ForCausalLM.set_layer_placements`** (line ~2896): + - Iterates through all layers + - For each layer that is an `Apriel2StochasticDecoderLayer`, calls `layer.set_active_mixer(mixer_name)` with the corresponding entry from the placement list + +4. **`Apriel2StochasticMixer.set_active_mixer`** (line ~2454): + - Simply sets `self.active_mixer_name = name` (after validation) + +5. On the **next `llm.generate()` call**, the forward pass hits `stochastic_mixer_dispatch` which reads the updated `active_mixer_name` and routes to the new mixer. + +## Summary Diagram + +``` +debug_offline.py + | + +-- llm.collective_rpc("get_mixer_names") + | -> Worker.get_mixer_names -> model.get_mixer_names + | -> returns ("attention", "gdn", ...) from first stochastic layer + | + +-- llm.collective_rpc("get_layer_placements") + | -> Worker.get_layer_placements -> model.get_layer_placements + | -> returns {layer_idx: active_mixer_name} for stochastic layers + | + +-- llm.collective_rpc("set_layer_placements", args=(placement,)) + | -> Worker._set_layer_placements + | +-- _clear_kv_cache() <- zero all cache tensors + | +-- model.set_layer_placements(placement) + | +-- for each stochastic layer: + | layer.mixer.active_mixer_name = new_name + | + +-- llm.generate(prompts, ...) + -> forward pass per layer: + -> stochastic_mixer_dispatch (custom op, graph break) + -> looks up self.active_mixer_name + -> calls active_mixer.forward(hidden_states, output, positions) +``` + +## Key Insight + +All mixer weights are **always loaded** — switching is just flipping `active_mixer_name` and clearing the cache. The custom op mechanism ensures this dynamic routing works even with torch.compile/CUDA graphs by forcing graph breaks at dispatch points. diff --git a/fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md b/fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md new file mode 100644 index 000000000..f90327565 --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md @@ -0,0 +1,105 @@ +# Per-Mixer CUDA Graph Caching for Stochastic Mixers + +## Problem + +Apriel2's supernet uses `Apriel2StochasticMixer` — a wrapper that routes each layer to one of several mixer types (attention, GDN, KDA, sliding_window) based on the active placement. The `stochastic_mixer_dispatch` custom op is registered as a CUDA graph splitting op, forcing vLLM into **PIECEWISE** mode. + +In PIECEWISE mode, vLLM captures CUDA graphs for the compute pieces between split points (norms, MLPs) but runs the split points themselves eagerly. This means every mixer forward at every layer incurs full kernel launch overhead on every decode step. + +**Measured impact**: PIECEWISE mode achieves ~290 tok/s, while FULL CUDA graph mode with a fixed layout (no supernet) achieves ~900 tok/s — a **3x gap** from the split points running eagerly. + +## Approach + +Cache a separate `torch.cuda.CUDAGraph` per (mixer_name, num_tokens) at each stochastic layer. During decode, replay the cached graph instead of running the mixer eagerly. + +### Implementation (in `modeling_apriel2.py`) + +- **`APRIEL2_MIXER_CUDA_GRAPHS`** env var — gates the feature (default `"0"`) +- **`MixerGraphEntry`** — dataclass holding a captured graph + input/output pointer addresses +- **`MixerGraphCache`** — per-layer cache keyed by `(mixer_name, num_tokens)` +- **`_capture_all_mixers_for_num_tokens()`** — captures graphs during `capture_model()` with eager warmup before each capture (for Triton autotuning) +- **`_batch_has_prefill()`** — detects mixed prefill-decode batches that can't use graph replay +- **`stochastic_mixer_dispatch`** modified with capture/replay/eager-fallback logic +- Cache instance stored as `Apriel2StochasticMixer._mixer_graph_cache` + +### Dispatch Flow + +```text +stochastic_mixer_dispatch(hidden_states, output, positions, layer_name): + if cache is not None: + if capturing and runtime_mode == PIECEWISE: + → capture graphs for all/active mixers, return + if not prefill_batch and cache.has(active_mixer, num_tokens): + → cache.replay(), return + → eager fallback: active_mixer(hidden_states, output, positions) +``` + +## Bugs Encountered & Fixed + +### 1. CUBLAS_STATUS_NOT_INITIALIZED during profile_run + +`cudagraph_capturing_enabled` defaults to `True` in `vllm.compilation.monitor`. During `profile_run()` (before `capture_model()`), our code tried to capture graphs, but cuBLAS wasn't initialized yet. + +**Fix**: Gate capture on `runtime_mode != CUDAGraphMode.NONE` (NONE during profile_run, PIECEWISE during capture_model). + +### 2. Triton autotuning inside graph capture + +KDA's `fused_kda_gate` uses `@triton.autotune`. First call triggers benchmarking with `cuda.synchronize()` — illegal during stream capture. + +**Fix**: Run each mixer eagerly once before capturing (warmup triggers autotuning outside capture context). + +### 3. GPU memory pressure from too many captured graphs (CRITICAL) + +Capturing graphs for all mixers at all batch sizes creates ~5,040 graphs (48 layers x 3 mixers x ~35 batch sizes). This causes a **2.2x throughput regression** regardless of whether graphs are replayed. + +## Memory Pressure Investigation + +Systematic isolation of the regression source: + +| Test Configuration | Graphs | Warmup tok/s | vs Baseline | +| ------------------------------------------- | ------ | ------------ | ----------- | +| `CUDA_GRAPHS=0` (baseline) | 0 | 290 | 1.0x | +| `CUDA_GRAPHS=1`, cache exists but empty | 0 | 290 | 1.0x | +| `CUDA_GRAPHS=1`, active mixer only captured | ~1,680 | 179 | 0.62x | +| `CUDA_GRAPHS=1`, capture only (no replay) | ~5,040 | 132 | 0.46x | +| `CUDA_GRAPHS=1`, capture + replay | ~5,040 | 125 | 0.43x | +| `CUDA_GRAPHS=1`, private pool + replay | ~5,040 | 126 | 0.43x | + +**Key findings**: + +1. **Python overhead is negligible** — empty cache has zero impact (290 tok/s) +2. **Graph replay adds ~5% cost** — minimal compared to the capture overhead +3. **Private graph pool doesn't help** — total GPU memory consumption is the issue, not fragmentation of vLLM's global pool +4. **Regression scales with graph count** — 1,680 graphs = 0.62x, 5,040 = 0.43x +5. The captured graphs consume GPU memory that degrades all inference operations (likely L2 cache pressure, TLB misses, or reduced memory for temporary allocations) + +## Current State + +The implementation is functionally correct but the "capture everything upfront" strategy is not viable due to memory pressure. The code remains in `modeling_apriel2.py` gated behind `APRIEL2_MIXER_CUDA_GRAPHS=1` (disabled by default). + +## Proposed Next Approach: Lazy Per-Placement Capture + +Instead of capturing all mixers for all batch sizes during `capture_model()`: + +1. **On placement set**: capture graphs only for the active mixer at each layer, only for batch sizes actually encountered +2. **On placement change**: invalidate old cache (free GPU memory), re-capture for the new placement +3. **Lazy batch sizes**: capture on first encounter of a new batch size during decode, not upfront for all 35 sizes + +This would keep the graph count to ~48 (one per layer per active batch size), well within the safe memory budget. + +### Open Questions + +- **TP > 1 compatibility**: NCCL must be in graph-safe mode for captures involving collective ops. During `capture_model()` this is guaranteed; during inference it is not. Lazy capture may only be safe at TP=1. +- **Capture-during-inference feasibility**: Need to verify that `torch.cuda.graph()` capture works correctly when called from a piecewise split point during normal inference (not during `capture_model()`). +- **Warmup cost**: Each lazy capture requires an eager warmup (for Triton autotuning) + the capture itself. This adds latency to the first decode step after a placement change or new batch size. + +## Reference: vLLM Startup Phases + +```text +load_weights → profile_run() → allocate KV cache → capture_model() → inference + │ │ + │ cudagraph_capturing=True │ cudagraph_capturing=True + │ runtime_mode=NONE │ runtime_mode=PIECEWISE + │ cuBLAS NOT initialized │ cuBLAS initialized + │ DO NOT capture here │ Safe to capture +``` diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 5ab525f61..d65e95544 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -9,10 +9,11 @@ import logging import math +import os from collections.abc import Iterable from dataclasses import dataclass from itertools import islice -from typing import Literal +from typing import Callable, Literal import torch import triton @@ -23,8 +24,10 @@ from transformers.activations import ACT2FN from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.monitor import cudagraph_capturing_enabled, validate_cudagraph_capturing_enabled from vllm.config import CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul @@ -60,7 +63,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch_utils import current_stream, direct_register_custom_op from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.selector import get_mamba_attn_backend @@ -102,6 +105,103 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i DEBUG_FINAL_NORM = False # Debug final norm before LM head DEBUG_LM_HEAD = False # Debug LM head input/output +# Per-mixer CUDA graph caching: capture a separate graph per sub-mixer variant +# per batch size, replayed during decode instead of running the mixer eagerly. +APRIEL2_MIXER_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_MIXER_CUDA_GRAPHS", "0") == "1" + + +# ============================================================================= +# Per-Mixer CUDA Graph Cache +# ============================================================================= + + +@dataclass +class MixerGraphEntry: + """A captured CUDA graph for one mixer at one batch size.""" + + graph: torch.cuda.CUDAGraph + input_ptr: int # hidden_states.data_ptr() at capture time + output_ptr: int # output.data_ptr() at capture time + + +class MixerGraphCache: + """Per-mixer, per-batch-size CUDA graph cache for stochastic mixers. + + Each entry is keyed by (mixer_name, num_tokens). Capture happens during + vLLM's capture_model() phase (when cudagraph_capturing_enabled is True + and NCCL is in graph-safe mode). Replay happens during normal decode. + """ + + def __init__(self) -> None: + self._entries: dict[tuple[str, int], MixerGraphEntry] = {} + # Use a PRIVATE graph pool to avoid fragmenting vLLM's global pool. + # The global pool is shared with piecewise graph pieces; interleaving + # our ~5000 mixer captures with vLLM's piece captures degrades + # piecewise replay performance ~2x due to pool fragmentation. + self._graph_pool = torch.cuda.graph_pool_handle() + + def has(self, mixer_name: str, num_tokens: int) -> bool: + return (mixer_name, num_tokens) in self._entries + + def capture( + self, + mixer_name: str, + num_tokens: int, + mixer_fn: Callable, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor | None, + ) -> None: + """Capture a CUDA graph for mixer_fn with the given inputs. + + Must be called during vLLM's capture_model() phase. + """ + validate_cudagraph_capturing_enabled() + key = (mixer_name, num_tokens) + if key in self._entries: + return + + graph = torch.cuda.CUDAGraph() + if self._graph_pool is not None: + set_graph_pool_id(self._graph_pool) + + with torch.cuda.graph(graph, pool=self._graph_pool, stream=current_stream()): + mixer_fn(hidden_states, output, positions=positions) + + self._entries[key] = MixerGraphEntry( + graph=graph, + input_ptr=hidden_states.data_ptr(), + output_ptr=output.data_ptr(), + ) + apriel2_logger.debug( + "MixerGraphCache: captured graph for (%s, %d), total=%d", + mixer_name, + num_tokens, + len(self._entries), + ) + + def replay( + self, + mixer_name: str, + num_tokens: int, + hidden_states: torch.Tensor, + output: torch.Tensor, + ) -> None: + """Replay the cached graph. Asserts pointer stability in debug mode.""" + entry = self._entries[(mixer_name, num_tokens)] + if __debug__: + assert hidden_states.data_ptr() == entry.input_ptr, ( + f"MixerGraphCache: hidden_states pointer changed between " + f"capture (0x{entry.input_ptr:x}) and replay " + f"(0x{hidden_states.data_ptr():x}) for ({mixer_name}, {num_tokens})" + ) + assert output.data_ptr() == entry.output_ptr, ( + f"MixerGraphCache: output pointer changed between " + f"capture (0x{entry.output_ptr:x}) and replay " + f"(0x{output.data_ptr():x}) for ({mixer_name}, {num_tokens})" + ) + entry.graph.replay() + # ============================================================================= # KV Cache Spec Computation @@ -867,20 +967,111 @@ def apriel2_gdn_attention_core_fake( # ============================================================================= +def _batch_has_prefill(forward_context: ForwardContext, active_mixer: nn.Module) -> bool: + """Return True if this batch contains prefill tokens. + + CUDA graphs captured during decode cannot handle prefill, so we must + fall back to eager execution for mixed batches. + """ + from vllm.config.compilation import CUDAGraphMode + + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.NONE: + return True # Profile/warmup run — treat as non-graphable + + attn_meta = forward_context.attn_metadata + if isinstance(attn_meta, dict): + mixer_prefix = getattr(active_mixer, "prefix", None) + if mixer_prefix is not None: + meta = attn_meta.get(mixer_prefix) + if meta is not None and hasattr(meta, "num_prefills"): + return meta.num_prefills > 0 + return False + + +def _capture_all_mixers_for_num_tokens( + stochastic_mixer: "Apriel2StochasticMixer", + cache: MixerGraphCache, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor | None, + num_tokens: int, +) -> None: + """Capture CUDA graphs for ALL sub-mixers at this batch size. + + Called during vLLM's capture_model() phase. All mixers are captured + writing into the same ``output`` buffer address so that any mixer's + graph can be replayed into the current output on future decode steps. + + After capturing, the active mixer is run eagerly once to leave the + correct result in ``output`` for downstream layers. + """ + active_name = stochastic_mixer.active_mixer_name + + for mixer_name, mixer in stochastic_mixer.mixers.items(): + if cache.has(mixer_name, num_tokens): + continue + + # Eager warmup: run the mixer once outside graph capture to trigger + # Triton autotuning (which calls cuda.synchronize() internally). + # Without this, autotuning during torch.cuda.graph() capture causes + # "operation not permitted when stream is capturing". + torch.cuda.synchronize() + output.zero_() + mixer(hidden_states, output, positions=positions) + + torch.cuda.synchronize() + output.zero_() + + apriel2_logger.info( + "Capturing mixer CUDA graph: layer=%s mixer=%s num_tokens=%d", + stochastic_mixer.prefix, + mixer_name, + num_tokens, + ) + cache.capture(mixer_name, num_tokens, mixer, hidden_states, output, positions) + + # Restore output to the active mixer's result for the outer capture pass + torch.cuda.synchronize() + stochastic_mixer.mixers[active_name](hidden_states, output, positions=positions) + + def stochastic_mixer_dispatch( hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor | None, layer_name: str, ) -> None: - """Dispatch to the active mixer at runtime (escapes torch.compile).""" + """Dispatch to the active mixer; capture or replay CUDA graphs when enabled.""" forward_context: ForwardContext = get_forward_context() stochastic_mixer = forward_context.no_compile_layers[layer_name] + active_name: str = stochastic_mixer.active_mixer_name + active_mixer = stochastic_mixer.mixers[active_name] + cache: MixerGraphCache | None = stochastic_mixer._mixer_graph_cache + + if cache is not None: + num_tokens = hidden_states.shape[0] + + # Capture phase: capture graphs for ALL mixers at this batch size. + # We must check runtime_mode to avoid capturing during profile_run, + # where cudagraph_capturing_enabled is True but cuBLAS hasn't been + # lazily initialized yet (CUBLAS_STATUS_NOT_INITIALIZED). + # During profile_run, runtime_mode is NONE; during capture_model() + # it is PIECEWISE. + from vllm.config.compilation import CUDAGraphMode + + runtime_mode = forward_context.cudagraph_runtime_mode + if cudagraph_capturing_enabled and runtime_mode is not None and runtime_mode != CUDAGraphMode.NONE: + _capture_all_mixers_for_num_tokens(stochastic_mixer, cache, hidden_states, output, positions, num_tokens) + return - # Get the currently active mixer (runtime lookup) - active_mixer = stochastic_mixer.mixers[stochastic_mixer.active_mixer_name] + # Replay phase: use cached graph if available and batch is decode-only + has_prefill = _batch_has_prefill(forward_context, active_mixer) + has_cached = cache.has(active_name, num_tokens) + if not has_prefill and has_cached: + cache.replay(active_name, num_tokens, hidden_states, output) + return - # Forward through the active mixer + # Eager fallback active_mixer(hidden_states, output, positions=positions) @@ -2451,6 +2642,9 @@ def __init__( # FULL mode captures only active mixer ops, breaking dormant mixer switching _force_piecewise_cudagraph_for_stochastic_mixers() + # Per-mixer CUDA graph cache (populated during capture_model() phase) + self._mixer_graph_cache: MixerGraphCache | None = MixerGraphCache() if APRIEL2_MIXER_CUDA_GRAPHS else None + def set_active_mixer(self, name: str) -> None: """Set the active mixer by name.""" if name not in self.mixers: From f6c11c2f03d559e495b3fb4158e5c1ba70458c54 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 23 Feb 2026 19:51:01 +0000 Subject: [PATCH 02/13] documentation --- .../apriel2/vllm/modeling_apriel2.py | 184 ++++++++++++++++-- 1 file changed, 165 insertions(+), 19 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index d65e95544..0a0a14bbb 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -105,8 +105,104 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i DEBUG_FINAL_NORM = False # Debug final norm before LM head DEBUG_LM_HEAD = False # Debug LM head input/output -# Per-mixer CUDA graph caching: capture a separate graph per sub-mixer variant -# per batch size, replayed during decode instead of running the mixer eagerly. +# ============================================================================= +# CUDA Graph Modes for Stochastic Mixers +# ============================================================================= +# +# The Apriel2StochasticMixer wraps multiple sub-mixers (attention, GDN, KDA) +# per layer and routes to the active one at runtime. This interacts with +# vLLM's CUDA graph capture in several modes, controlled by env vars. +# +# Benchmark setup: 10 concurrent requests, all-attention layout, prompt +# length 1, max generation length 16k, REST backend, no warmup after +# local vLLM launch, single H100 80GB. +# +# ┌──────────────────────────────────────────────────────────────────────┐ +# │ Mode 0: Fixed-layout FULL graphs (baseline, no supernet) │ +# ├──────────────────────────────────────────────────────────────────────┤ +# │ Serve a checkpoint with a predefined layout (single mixer per │ +# │ layer). Standard vLLM FULL CUDA graph capture — no stochastic │ +# │ dispatch involved. │ +# │ │ +# │ Weights: 26.91 GiB KV cache: 43.50 GiB Graphs: 0.10 GiB │ +# │ Throughput: 583 tok/s │ +# │ │ +# │ This is the upper bound — all mixer weights are collapsed into one │ +# │ per layer, leaving maximum memory for KV cache. │ +# └──────────────────────────────────────────────────────────────────────┘ +# +# ┌──────────────────────────────────────────────────────────────────────┐ +# │ Mode 1: Supernet + FULL graphs (APRIEL2_FULL_CUDA_GRAPHS=1) │ +# │ [default] │ +# ├──────────────────────────────────────────────────────────────────────┤ +# │ vLLM captures the entire forward as one CUDA graph per batch size. │ +# │ stochastic_mixer_dispatch is NOT a graph-splitting op: during │ +# │ capture it calls the active mixer, baking its GPU kernels into the │ +# │ graph. On replay the same kernels execute. │ +# │ │ +# │ On layout change (set_layer_placements), all captured graphs are │ +# │ invalidated and re-captured via capture_model() (~5-15 s). │ +# │ │ +# │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: 0.09 GiB │ +# │ Throughput: ~200 tok/s │ +# │ │ +# │ All 48×4 = 192 sub-mixer weights are loaded, consuming 19 GiB │ +# │ more than fixed-layout. The fixed-layout model has 1.78× more KV │ +# │ cache (43.5 vs 24.5 GiB), allowing far more tokens in-flight — │ +# │ this is the primary cause of the throughput gap to Mode 0. │ +# └──────────────────────────────────────────────────────────────────────┘ +# +# ┌──────────────────────────────────────────────────────────────────────┐ +# │ Mode 2: Supernet + PIECEWISE + per-mixer sub-graphs │ +# │ (APRIEL2_FULL_CUDA_GRAPHS=0, APRIEL2_MIXER_CUDA_GRAPHS=1) │ +# ├──────────────────────────────────────────────────────────────────────┤ +# │ stochastic_mixer_dispatch is a graph-splitting op → vLLM forces │ +# │ PIECEWISE mode. At each dispatch point, a separate small CUDA │ +# │ graph is cached per (mixer, batch_size) and replayed. │ +# │ │ +# │ No re-capture needed on layout change (dispatch selects mixer at │ +# │ runtime), but creates ~5k graphs causing GPU memory pressure. │ +# │ │ +# │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: 2.43 GiB │ +# │ Capture time: ~45 s │ +# │ Throughput: ~155 tok/s │ +# │ │ +# │ WARNING: Not recommended. The graph memory further reduces │ +# │ available KV cache and capture overhead is substantial. │ +# └──────────────────────────────────────────────────────────────────────┘ +# +# ┌──────────────────────────────────────────────────────────────────────┐ +# │ Mode 3: Supernet + PIECEWISE + eager dispatch │ +# │ (APRIEL2_FULL_CUDA_GRAPHS=0, APRIEL2_MIXER_CUDA_GRAPHS=0) │ +# ├──────────────────────────────────────────────────────────────────────┤ +# │ Same as Mode 2 but mixer forward runs fully eagerly (no per-mixer │ +# │ graph caching). Graph breaks at every dispatch, Python selects the │ +# │ active mixer each step. │ +# │ │ +# │ No re-capture needed on layout change. │ +# │ │ +# │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: ~0 GiB │ +# │ Throughput: 151 tok/s │ +# └──────────────────────────────────────────────────────────────────────┘ +# +# Summary (H100 80 GB, all-attention layout, 10 concurrent reqs, 16k output, prompt length 1): +# +# Mode │ Supernet │ FULL │ Per-mixer subgraph │ Weights │ KV cache │ Graphs │ tok/s +# ─────┼──────────┼──────┼────────────────────┼─────────┼──────────┼────────┼────── +# 0 │ no │ on │ - │ 26.9 Gi │ 43.5 Gi │ 0.1 Gi │ 583 +# 1 │ yes │ on │ - │ 45.9 Gi │ 24.5 Gi │ 0.1 Gi │ ~200 +# 2 │ yes │ off │ on │ 45.9 Gi │ 24.5 Gi │ 2.4 Gi │ ~155 +# 3 │ yes │ off │ off │ 45.9 Gi │ 24.5 Gi │ ~0 Gi │ ~151 +# +# FULL = APRIEL2_FULL_CUDA_GRAPHS +# Per-mixer subgraph = APRIEL2_MIXER_CUDA_GRAPHS +# +# Note: CUDA graph capture is essential for linear mixers (GDN, KDA). +# They will be slower than attention in Mode 3 (eager dispatch) but +# can be faster than attention in Mode 1 (FULL graphs). +# +# ============================================================================= +APRIEL2_FULL_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_FULL_CUDA_GRAPHS", "1") == "1" APRIEL2_MIXER_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_MIXER_CUDA_GRAPHS", "0") == "1" @@ -1097,17 +1193,19 @@ def stochastic_mixer_dispatch_fake( fake_impl=stochastic_mixer_dispatch_fake, ) -# Add stochastic_mixer_dispatch to vLLM's splitting ops so it causes graph breaks -# This allows dynamic dispatch at runtime even with CUDA graphs enabled -try: - from vllm.config.compilation import CompilationConfig, CUDAGraphMode +# Add stochastic_mixer_dispatch to vLLM's splitting ops so it causes graph breaks. +# Only needed in PIECEWISE mode — in FULL mode the dispatch is transparent to +# graph capture (the active mixer's kernels get baked into the full graph). +if not APRIEL2_FULL_CUDA_GRAPHS: + try: + from vllm.config.compilation import CompilationConfig, CUDAGraphMode - _stochastic_op = "vllm::stochastic_mixer_dispatch" - if _stochastic_op not in CompilationConfig._attention_ops: - CompilationConfig._attention_ops.append(_stochastic_op) - logger.info(f"Added {_stochastic_op} to vLLM splitting ops") -except ImportError: - logger.warning("Could not add stochastic_mixer_dispatch to vLLM splitting ops") + _stochastic_op = "vllm::stochastic_mixer_dispatch" + if _stochastic_op not in CompilationConfig._attention_ops: + CompilationConfig._attention_ops.append(_stochastic_op) + logger.info(f"Added {_stochastic_op} to vLLM splitting ops") + except ImportError: + logger.warning("Could not add stochastic_mixer_dispatch to vLLM splitting ops") def _force_piecewise_cudagraph_for_stochastic_mixers(): @@ -2638,12 +2736,15 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - # Force PIECEWISE cudagraph mode for stochastic mixers - # FULL mode captures only active mixer ops, breaking dormant mixer switching - _force_piecewise_cudagraph_for_stochastic_mixers() - - # Per-mixer CUDA graph cache (populated during capture_model() phase) - self._mixer_graph_cache: MixerGraphCache | None = MixerGraphCache() if APRIEL2_MIXER_CUDA_GRAPHS else None + if APRIEL2_FULL_CUDA_GRAPHS: + # FULL mode: the entire forward (including dispatch) is captured as + # one CUDA graph. On layout change, graphs are re-captured. + # No per-mixer cache needed — vLLM's CUDAGraphWrapper handles it. + self._mixer_graph_cache: MixerGraphCache | None = None + else: + # PIECEWISE mode: force graph break at dispatch, mixer runs eagerly + _force_piecewise_cudagraph_for_stochastic_mixers() + self._mixer_graph_cache = MixerGraphCache() if APRIEL2_MIXER_CUDA_GRAPHS else None def set_active_mixer(self, name: str) -> None: """Set the active mixer by name.""" @@ -3204,11 +3305,56 @@ def _clear_kv_cache(self) -> None: logger.info("Cleared KV cache tensors for placement switch") + def _clear_piecewise_wrappers(module: nn.Module) -> None: + """Recursively clear all piecewise CUDAGraphWrapper entries.""" + from vllm.compilation.cuda_graph import CUDAGraphWrapper + + for val in module.__dict__.values(): + if isinstance(val, CUDAGraphWrapper): + val.concrete_cudagraph_entries.clear() + elif isinstance(val, nn.Module): + _clear_piecewise_wrappers(val) + for child in module.children(): + _clear_piecewise_wrappers(child) + + def _recapture_cuda_graphs(worker) -> None: + """Invalidate and re-capture CUDA graphs after layout change. + + In FULL CUDA graph mode, the captured graphs contain GPU kernels for + the previous layout. After changing mixer assignments, we must + re-capture to bake in the new layout's kernels. + """ + model_runner = getattr(worker, "model_runner", None) + if model_runner is None: + return + + from vllm.compilation.cuda_graph import CUDAGraphWrapper + + # 1. Clear outer FULL wrapper entries + model = model_runner.model + if isinstance(model, CUDAGraphWrapper): + num_cleared = len(model.concrete_cudagraph_entries) + model.concrete_cudagraph_entries.clear() + logger.info(f"Cleared {num_cleared} FULL CUDA graph entries") + + # 2. Clear inner piecewise wrapper entries (for FULL_AND_PIECEWISE mode) + inner_model = model.unwrap() if isinstance(model, CUDAGraphWrapper) else model + _clear_piecewise_wrappers(inner_model) + + # 3. Re-capture for all batch sizes + logger.info("Re-capturing CUDA graphs for new layout...") + model_runner.capture_model() + logger.info("CUDA graph re-capture complete") + def _set_layer_placements(self, placement: list[str]) -> dict[int, str]: # Clear KV cache BEFORE changing placement to prevent reading stale data # written by a different mixer type (which could cause NaN errors) _clear_kv_cache(self) - return self.get_model().set_layer_placements(placement) + result = self.get_model().set_layer_placements(placement) + # Re-capture CUDA graphs with the new layout baked in + if APRIEL2_FULL_CUDA_GRAPHS and result: + _recapture_cuda_graphs(self) + return result def _get_mixer_names(self) -> tuple[str, ...]: return self.get_model().get_mixer_names() From 2dd41610a61e579000bb085e7112c591bc6e0e78 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 23 Feb 2026 22:48:28 +0000 Subject: [PATCH 03/13] offload not active mixers to cpu --- .../apriel2/vllm/modeling_apriel2.py | 193 ++++++++++++++++-- 1 file changed, 178 insertions(+), 15 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 0a0a14bbb..7cee923ed 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -132,7 +132,8 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i # └──────────────────────────────────────────────────────────────────────┘ # # ┌──────────────────────────────────────────────────────────────────────┐ -# │ Mode 1: Supernet + FULL graphs (APRIEL2_FULL_CUDA_GRAPHS=1) │ +# │ Mode 1: Supernet + FULL graphs + weight offload │ +# │ (APRIEL2_FULL_CUDA_GRAPHS=1, APRIEL2_OFFLOAD_INACTIVE_MIXERS=1)│ # │ [default] │ # ├──────────────────────────────────────────────────────────────────────┤ # │ vLLM captures the entire forward as one CUDA graph per batch size. │ @@ -140,16 +141,32 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i # │ capture it calls the active mixer, baking its GPU kernels into the │ # │ graph. On replay the same kernels execute. │ # │ │ -# │ On layout change (set_layer_placements), all captured graphs are │ -# │ invalidated and re-captured via capture_model() (~5-15 s). │ +# │ After profile_run(), inactive mixer weights are offloaded to CPU. │ +# │ This reduces GPU weight memory to ~26.9 GiB (matching Mode 0) and │ +# │ frees ~19 GiB for KV cache. Only parameters are moved — shared │ +# │ buffers (e.g. RotaryEmbedding cos_sin_cache) stay on GPU. │ +# │ Offloaded mixers are also removed from nn.ModuleDict to avoid │ +# │ torch.compile guard invalidation. │ +# │ │ +# │ On layout change (set_layer_placements): │ +# │ 1. KV cache is cleared │ +# │ 2. Weights are swapped layer by layer (old→CPU, new→GPU) │ +# │ 3. All captured CUDA graphs are invalidated and re-captured │ +# │ via capture_model() (~5-15 s) │ +# │ │ +# │ Weights: 26.91 GiB KV cache: 43.50 GiB Graphs: 0.07 GiB │ +# │ Throughput: 516 tok/s │ +# └──────────────────────────────────────────────────────────────────────┘ +# +# ┌──────────────────────────────────────────────────────────────────────┐ +# │ Mode 1b: Supernet + FULL graphs, no offload │ +# │ (APRIEL2_FULL_CUDA_GRAPHS=1, APRIEL2_OFFLOAD_INACTIVE_MIXERS=0)│ +# ├──────────────────────────────────────────────────────────────────────┤ +# │ Same as Mode 1 but all mixer weights stay on GPU. This wastes │ +# │ ~19 GiB on inactive mixers, leaving 1.78× less KV cache. │ # │ │ # │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: 0.09 GiB │ # │ Throughput: ~200 tok/s │ -# │ │ -# │ All 48×4 = 192 sub-mixer weights are loaded, consuming 19 GiB │ -# │ more than fixed-layout. The fixed-layout model has 1.78× more KV │ -# │ cache (43.5 vs 24.5 GiB), allowing far more tokens in-flight — │ -# │ this is the primary cause of the throughput gap to Mode 0. │ # └──────────────────────────────────────────────────────────────────────┘ # # ┌──────────────────────────────────────────────────────────────────────┐ @@ -187,14 +204,16 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i # # Summary (H100 80 GB, all-attention layout, 10 concurrent reqs, 16k output, prompt length 1): # -# Mode │ Supernet │ FULL │ Per-mixer subgraph │ Weights │ KV cache │ Graphs │ tok/s -# ─────┼──────────┼──────┼────────────────────┼─────────┼──────────┼────────┼────── -# 0 │ no │ on │ - │ 26.9 Gi │ 43.5 Gi │ 0.1 Gi │ 583 -# 1 │ yes │ on │ - │ 45.9 Gi │ 24.5 Gi │ 0.1 Gi │ ~200 -# 2 │ yes │ off │ on │ 45.9 Gi │ 24.5 Gi │ 2.4 Gi │ ~155 -# 3 │ yes │ off │ off │ 45.9 Gi │ 24.5 Gi │ ~0 Gi │ ~151 +# Mode │ Supernet │ FULL │ Offload │ Per-mixer subgraph │ Weights │ KV cache │ Graphs │ tok/s +# ─────┼──────────┼──────┼─────────┼────────────────────┼─────────┼──────────┼────────┼────── +# 0 │ no │ on │ - │ - │ 26.9 Gi │ 43.5 Gi │ 0.1 Gi │ 583 +# 1 │ yes │ on │ on │ - │ 26.9 Gi │ 43.5 Gi │ 0.1 Gi │ 516 +# 1b │ yes │ on │ off │ - │ 45.9 Gi │ 24.5 Gi │ 0.1 Gi │ ~200 +# 2 │ yes │ off │ off │ on │ 45.9 Gi │ 24.5 Gi │ 2.4 Gi │ ~155 +# 3 │ yes │ off │ off │ off │ 45.9 Gi │ 24.5 Gi │ ~0 Gi │ ~151 # # FULL = APRIEL2_FULL_CUDA_GRAPHS +# Offload = APRIEL2_OFFLOAD_INACTIVE_MIXERS # Per-mixer subgraph = APRIEL2_MIXER_CUDA_GRAPHS # # Note: CUDA graph capture is essential for linear mixers (GDN, KDA). @@ -205,6 +224,29 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i APRIEL2_FULL_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_FULL_CUDA_GRAPHS", "1") == "1" APRIEL2_MIXER_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_MIXER_CUDA_GRAPHS", "0") == "1" +# Offload inactive mixer weights to CPU after profile_run(). Frees ~19 GiB +# GPU memory for KV cache. On layout switch, weights are swapped layer by layer. +# Note: cannot offload during load_weights() because torch.compile captures all +# parameters as graph inputs — profile_run() would crash with CPU tensors. +# Default: enabled with FULL CUDA graphs (layout is fixed per graph capture). +APRIEL2_OFFLOAD_INACTIVE_MIXERS: bool = ( + os.environ.get("APRIEL2_OFFLOAD_INACTIVE_MIXERS", "1" if APRIEL2_FULL_CUDA_GRAPHS else "0") == "1" +) + + +def _move_module_device(module: nn.Module, device: torch.device) -> None: + """Move a module's weight parameters to a device. + + Uses param.data assignment (not module.to()) to preserve vLLM's + BasevLLMParameter metadata (_weight_loader attribute). + + Only moves parameters, NOT buffers. Buffers like RotaryEmbedding's + cos_sin_cache are shared across mixer instances (via get_rope() LRU cache). + Moving them would corrupt the active mixer's shared state. + """ + for param in module.parameters(): + param.data = param.data.to(device, non_blocking=True) + # ============================================================================= # Per-Mixer CUDA Graph Cache @@ -2748,7 +2790,7 @@ def __init__( def set_active_mixer(self, name: str) -> None: """Set the active mixer by name.""" - if name not in self.mixers: + if name not in self._mixer_names: raise ValueError(f"Unknown mixer '{name}'. Available: {self._mixer_names}") self.active_mixer_name = name @@ -2756,6 +2798,29 @@ def get_active_mixer(self) -> str: """Get the name of the currently active mixer.""" return self.active_mixer_name + def offload_inactive_mixers(self) -> int: + """Move inactive mixer weights to CPU. Returns bytes freed. + + Also removes offloaded mixers from self.mixers (nn.ModuleDict) and + stores them in self._offloaded_mixers (plain dict). This is critical: + torch.compile/dynamo sets guards on every parameter it sees in the + module tree. If offloaded params stay in the tree with device='cpu', + a subsequent forward triggers guard failure → re-trace → crash. + Hiding them from the module tree avoids the issue entirely. + """ + freed = 0 + device_cpu = torch.device("cpu") + self._offloaded_mixers: dict[str, nn.Module] = {} + to_offload = [name for name in self.mixers if name != self.active_mixer_name] + for name in to_offload: + mixer = self.mixers[name] + for param in mixer.parameters(): + freed += param.data.nbytes + _move_module_device(mixer, device_cpu) + self._offloaded_mixers[name] = mixer + del self.mixers[name] + return freed + def forward( self, hidden_states: torch.Tensor, @@ -3198,6 +3263,28 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + def offload_inactive_mixers(self) -> int: + """Offload all inactive mixer weights to CPU. + + Cannot run inside load_weights() because torch.compile captures ALL + parameters as compiled graph inputs — profile_run() would crash trying + to pass CPU tensors to the GPU graph. Instead, this is called from the + monkey-patched Worker.determine_available_memory() AFTER profile_run(). + + Returns: + Total bytes freed on GPU. + """ + total_freed = 0 + for layer in self.model.layers: + if isinstance(layer, Apriel2StochasticDecoderLayer): + total_freed += layer.mixer.offload_inactive_mixers() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + freed_gib = total_freed / (1024**3) + logger.info(f"Offloaded inactive mixer weights to CPU: freed {freed_gib:.2f} GiB GPU memory") + return total_freed + def set_layer_placements(self, placement: list[str]) -> dict[int, str]: """Set the active mixer for each stochastic layer. @@ -3346,10 +3433,59 @@ def _recapture_cuda_graphs(worker) -> None: model_runner.capture_model() logger.info("CUDA graph re-capture complete") + def _swap_mixer_weights(worker, placement: list[str]) -> None: + """Swap mixer weights between GPU and CPU for placement change. + + For each layer where the active mixer changes: + 1. Offload old active → CPU, remove from ModuleDict, store in _offloaded_mixers + 2. Load new active from _offloaded_mixers → GPU, add to ModuleDict + Done layer by layer to avoid transient OOM. + """ + model = worker.get_model() + device_gpu = torch.device("cuda") + device_cpu = torch.device("cpu") + loaded = 0 + offloaded = 0 + + for layer_idx, new_mixer_name in enumerate(placement): + if layer_idx >= len(model.model.layers): + break + layer = model.model.layers[layer_idx] + if not isinstance(layer, Apriel2StochasticDecoderLayer): + continue + + stochastic = layer.mixer + old_mixer_name = stochastic.active_mixer_name + if old_mixer_name == new_mixer_name: + continue + + # 1. Offload old active → CPU, hide from module tree + old_mixer = stochastic.mixers[old_mixer_name] + _move_module_device(old_mixer, device_cpu) + del stochastic.mixers[old_mixer_name] + stochastic._offloaded_mixers[old_mixer_name] = old_mixer + offloaded += 1 + + # 2. Load new active from offloaded → GPU, restore to module tree + new_mixer = stochastic._offloaded_mixers.pop(new_mixer_name) + _move_module_device(new_mixer, device_gpu) + stochastic.mixers[new_mixer_name] = new_mixer + loaded += 1 + + if loaded or offloaded: + torch.cuda.synchronize() + torch.cuda.empty_cache() + logger.info(f"Weight swap: offloaded {offloaded} mixers to CPU, " f"loaded {loaded} mixers to GPU") + def _set_layer_placements(self, placement: list[str]) -> dict[int, str]: # Clear KV cache BEFORE changing placement to prevent reading stale data # written by a different mixer type (which could cause NaN errors) _clear_kv_cache(self) + + # Swap weights before changing active mixer (needs old active_mixer_name) + if APRIEL2_OFFLOAD_INACTIVE_MIXERS: + _swap_mixer_weights(self, placement) + result = self.get_model().set_layer_placements(placement) # Re-capture CUDA graphs with the new layout baked in if APRIEL2_FULL_CUDA_GRAPHS and result: @@ -3359,6 +3495,33 @@ def _set_layer_placements(self, placement: list[str]) -> dict[int, str]: def _get_mixer_names(self) -> tuple[str, ...]: return self.get_model().get_mixer_names() + # -- Weight offloading: patch determine_available_memory --------------- + # torch.compile captures ALL parameters as graph inputs. If we offload + # during load_weights(), profile_run() crashes (CPU tensors in GPU graph). + # Instead, offload AFTER profile_run() but adjust available memory upward. + + if APRIEL2_OFFLOAD_INACTIVE_MIXERS: + _orig_determine_available_memory = Worker.determine_available_memory + + @torch.inference_mode() + def _determine_available_memory_with_offload(self) -> int: + result = _orig_determine_available_memory(self) + + # Offload inactive mixer weights to CPU + freed = self.get_model().offload_inactive_mixers() + if freed > 0: + self.available_kv_cache_memory_bytes += freed + self.model_runner.model_memory_usage -= freed + result = int(self.available_kv_cache_memory_bytes) + logger.info( + "Adjusted available KV cache memory: +%.2f GiB from weight offloading", + freed / (1024**3), + ) + + return result + + Worker.determine_available_memory = _determine_available_memory_with_offload + Worker.get_layer_placements = _get_layer_placements Worker.set_layer_placements = _set_layer_placements Worker.get_mixer_names = _get_mixer_names From 2dfe6a7ebde3c1a1aa4802202b7f2d863792898b Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Mar 2026 21:14:40 +0000 Subject: [PATCH 04/13] wip: FULL graph investigation. Previously Apriel2 set is_hybrid=False and bypassed vLLM's config verification entirely because the default dispatch couldn't handle heterogeneous or pure-mamba compositions. These changes introduce a proper custom config handler that correctly routes to the right vLLM verification path based on the actual layer composition, and provides the get_mamba_state_shape/dtype methods needed for hybrid page size alignment. --- .../apriel2/vllm/config_convertor.py | 46 ++++++- .../apriel2/vllm/modeling_apriel2.py | 119 ++++++++++++++++-- 2 files changed, 156 insertions(+), 9 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/config_convertor.py b/fast_llm_external_models/apriel2/vllm/config_convertor.py index 02bbecb7a..86c890559 100644 --- a/fast_llm_external_models/apriel2/vllm/config_convertor.py +++ b/fast_llm_external_models/apriel2/vllm/config_convertor.py @@ -9,6 +9,12 @@ """ from vllm import ModelRegistry +from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, + HybridAttentionMambaModelConfig, + MambaModelConfig, + VerifyAndUpdateConfig, +) from vllm.transformers_utils.model_arch_config_convertor import ( MODEL_ARCH_CONFIG_CONVERTORS, ModelArchConfigConvertorBase, @@ -66,6 +72,39 @@ def get_head_size(self) -> int: return self._get_first_attention_block().get("head_size", 0) +class Apriel2ModelConfig(VerifyAndUpdateConfig): + """Config handler for Apriel2 models with heterogeneous mixer types. + + Apriel2 can be pure-attention, pure-mamba, or hybrid (attention + mamba) + depending on the decoder config. vLLM's default ``is_hybrid`` dispatch + calls ``HybridAttentionMambaModelConfig`` which crashes for pure-mamba + models (``ZeroDivisionError`` when ``num_kv_heads=0``). + + This handler inspects ``layers_block_type`` on the HF config to determine + the model composition and routes to the correct config handler. + """ + + @staticmethod + def verify_and_update_config(vllm_config) -> None: + hf_config = vllm_config.model_config.hf_config + layer_types = getattr(hf_config, "layers_block_type", None) + + if layer_types is None: + # Fallback: no layer type info — assume standard transformer. + return + + has_attention = any(t == "attention" for t in layer_types) + has_mamba = any(t == "mamba" for t in layer_types) + + if has_attention and has_mamba: + # Hybrid: attention + mamba page size alignment required. + HybridAttentionMambaModelConfig.verify_and_update_config(vllm_config) + elif has_mamba: + # Pure mamba: enable FULL_AND_PIECEWISE, set mamba_block_size. + MambaModelConfig.verify_and_update_config(vllm_config) + # Pure attention: no special config needed. + + def register(): """Register Apriel2 models and config convertors with vLLM. @@ -126,7 +165,7 @@ def register(): # Best-effort only; vLLM can still proceed with the generic config. pass - # Register model class + # Register model class and config handler. # Note: some exported checkpoints may list "Apriel2ForConditionalGeneration" # in config.json's "architectures". vLLM's model selection is driven by that # field, so we alias it to the same vLLM implementation for text-only usage. @@ -135,3 +174,8 @@ def register(): arch, "fast_llm_external_models.apriel2.vllm:Apriel2ForCausalLM", ) + # Register in MODELS_CONFIG_MAP so vLLM calls our handler instead of + # relying on the is_hybrid class attribute dispatch (which can't handle + # models that are sometimes hybrid, sometimes pure-mamba). + if arch not in MODELS_CONFIG_MAP: + MODELS_CONFIG_MAP[arch] = Apriel2ModelConfig diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 7cee923ed..26177eb3b 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -22,7 +22,7 @@ from torch import nn from transformers import PretrainedConfig from transformers.activations import ACT2FN -from vllm.attention.layer import Attention +from vllm import __version_tuple__ as _vllm_version from vllm.compilation.decorators import support_torch_compile from vllm.compilation.monitor import cudagraph_capturing_enabled, validate_cudagraph_capturing_enabled from vllm.config import CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config @@ -69,6 +69,11 @@ from vllm.v1.attention.selector import get_mamba_attn_backend from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec +if _vllm_version >= (0, 16, 0): + from vllm.model_executor.layers.attention import Attention +else: + from vllm.attention.layer import Attention + # Lazy triton allocator setup - only called when a triton kernel needs scratch memory _triton_allocator_installed = False @@ -104,6 +109,7 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i DEBUG_DECODER_LAYER = False # Debug decoder layer outputs (residual, norm) DEBUG_FINAL_NORM = False # Debug final norm before LM head DEBUG_LM_HEAD = False # Debug LM head input/output +DEBUG_STATE_INDICES = os.environ.get("APRIEL2_DEBUG_STATE_INDICES", "0") == "1" # ============================================================================= # CUDA Graph Modes for Stochastic Mixers @@ -748,17 +754,28 @@ def __init__( super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + # Attention-type mixers that use KV cache (not recurrent state) + _ATTENTION_MIXER_TYPES = {"attention", "sliding_window"} + @property def layers_block_type(self) -> list[str]: - """Return block types for each layer (for hybrid model detection).""" + """Return block types for each layer, normalized to 'attention' or 'mamba'. + + vLLM's get_num_layers_by_block_type() expects these two canonical types. + All attention-like mixers map to 'attention'; all recurrent mixers + (GDN, KDA, Mamba) map to 'mamba'. + """ decoder_config = self.decoder seq_type = decoder_config.get("type", "fixed") num_blocks = decoder_config.get("num_blocks", self.num_hidden_layers) + def _normalize(mixer_type: str) -> str: + return "attention" if mixer_type in self._ATTENTION_MIXER_TYPES else "mamba" + if seq_type == "fixed": block_config = decoder_config.get("block", {}) mixer_type = block_config.get("mixer", {}).get("type", "attention") - return [mixer_type] * num_blocks + return [_normalize(mixer_type)] * num_blocks elif seq_type == "pattern": pattern = decoder_config.get("pattern", ["attention"]) blocks_config = decoder_config.get("blocks", {}) @@ -766,7 +783,7 @@ def layers_block_type(self) -> list[str]: for i in range(num_blocks): block_name = pattern[i % len(pattern)] mixer_type = blocks_config.get(block_name, {}).get("mixer", {}).get("type", "attention") - result.append(mixer_type) + result.append(_normalize(mixer_type)) return result return ["attention"] * num_blocks @@ -1767,6 +1784,17 @@ def _forward_core( non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor num_actual_tokens = attn_metadata.num_actual_tokens + if DEBUG_STATE_INDICES and not cudagraph_capturing_enabled: + indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() + unique = {i for i in indices if i >= 0} + dup = len(indices) - len(unique) - indices.count(-1) + tag = " ** DUPLICATE **" if dup > 0 else "" + print( + f"[STATE-IDX GDN {self.prefix}] " + f"decodes={attn_metadata.num_decodes} prefills={attn_metadata.num_prefills} " + f"indices={indices}{tag}" + ) + # self._debug_print(f"num_actual_tokens={num_actual_tokens}, num_prefills={attn_metadata.num_prefills}, num_decodes={attn_metadata.num_decodes}") # self._debug_print(f"has_initial_state={has_initial_state}") # self._debug_print(f"non_spec_query_start_loc={non_spec_query_start_loc}") @@ -2205,6 +2233,17 @@ def _forward( num_actual_tokens = attn_metadata.num_actual_tokens constant_caches = self.kv_cache[forward_context.virtual_engine] + if DEBUG_STATE_INDICES and not cudagraph_capturing_enabled: + indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() + unique = {i for i in indices if i >= 0} + dup = len(indices) - len(unique) - indices.count(-1) + tag = " ** DUPLICATE **" if dup > 0 else "" + print( + f"[STATE-IDX KDA {self.prefix}] " + f"decodes={attn_metadata.num_decodes} prefills={attn_metadata.num_prefills} " + f"indices={indices}{tag}" + ) + q_proj_states = q_proj_states[:num_actual_tokens] k_proj_states = k_proj_states[:num_actual_tokens] v_proj_states = v_proj_states[:num_actual_tokens] @@ -3160,13 +3199,77 @@ class Apriel2ForCausalLM(nn.Module, HasInnerState, SupportsPP): }, ) - # For hybrid models + # has_inner_state = True: model uses recurrent state (GDN, KDA, Mamba). + # is_hybrid = False: we handle config verification ourselves via + # MODELS_CONFIG_MAP (see config_convertor.py) rather than relying on + # vLLM's default HybridAttentionMambaModelConfig dispatch. This is + # necessary because the default dispatch crashes for pure-mamba models + # (ZeroDivisionError: num_kv_heads=0 when no attention blocks exist). has_inner_state = True - # Don't use is_hybrid=True - it triggers HybridAttentionMambaModelConfig - # which assumes all mamba-like layers have the same shape. - # Apriel2 has heterogeneous blocks, each with its own get_kv_cache_spec(). is_hybrid = False + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: VllmConfig, + ) -> tuple: + """Return the largest mamba state shape across all mixer types. + + HybridAttentionMambaModelConfig.verify_and_update_config() calls this + to compute page size alignment. It only needs a conservative upper + bound — per-layer get_kv_cache_spec() handles actual heterogeneous + allocation. We return the shape of whichever mixer type produces the + largest page_size_bytes (the "envelope"). + """ + config = vllm_config.model_config.hf_config + decoder_config = getattr(config, "decoder", {}) or {} + blocks_config = get_blocks_config(decoder_config) + block_params = get_block_params(blocks_config, vllm_config) + + # Find the mamba block with the largest natural page size + best_shapes: tuple | None = None + best_page_size = 0 + for params in block_params.values(): + if isinstance(params, MambaBlockParams): + if params.natural_page_size > best_page_size: + best_page_size = params.natural_page_size + best_shapes = params.shapes + + if best_shapes is None: + # Pure attention model — return minimal shapes so + # verify_and_update_config() sees mamba_page_size=0 and returns. + return ((1, 1), (1, 1)) + + return best_shapes + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: VllmConfig, + ) -> tuple: + """Return dtypes matching the envelope mamba state shape. + + Must be consistent with get_mamba_state_shape_from_config() — returns + the dtypes of whichever mixer type has the largest page size. + """ + config = vllm_config.model_config.hf_config + decoder_config = getattr(config, "decoder", {}) or {} + blocks_config = get_blocks_config(decoder_config) + block_params = get_block_params(blocks_config, vllm_config) + + best_dtypes: tuple | None = None + best_page_size = 0 + for params in block_params.values(): + if isinstance(params, MambaBlockParams): + if params.natural_page_size > best_page_size: + best_page_size = params.natural_page_size + best_dtypes = params.dtypes + + if best_dtypes is None: + return (torch.float32, torch.float32) + + return best_dtypes + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() # Install triton allocator lazily - this runs in the vLLM subprocess From 899d4da576cdfa632918ddc737a4fdbfb16ed468 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Mar 2026 21:18:34 +0000 Subject: [PATCH 05/13] docuemnt new env vars --- .../apriel2/vllm/modeling_apriel2.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 26177eb3b..4dec49b9d 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -109,6 +109,11 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i DEBUG_DECODER_LAYER = False # Debug decoder layer outputs (residual, norm) DEBUG_FINAL_NORM = False # Debug final norm before LM head DEBUG_LM_HEAD = False # Debug LM head input/output +# Log ssm_state_indices at every GDN/KDA forward (decode only, suppressed +# during CUDA graph capture). Flags duplicate block IDs with "** DUPLICATE **". +# NOTE: In FULL CUDA graph mode, model-side logging never executes during replay; +# use the GDN metadata builder logging in gdn_attn.py instead (which also reads +# this env var). DEBUG_STATE_INDICES = os.environ.get("APRIEL2_DEBUG_STATE_INDICES", "0") == "1" # ============================================================================= @@ -227,14 +232,36 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i # can be faster than attention in Mode 1 (FULL graphs). # # ============================================================================= +# Capture the entire forward (including stochastic dispatch) as one monolithic +# CUDA graph per batch size. On placement change, ALL graphs are invalidated +# and re-captured via capture_model() (~5-15 s). When disabled, the dispatch +# op is registered as a graph-splitting op, forcing PIECEWISE mode where Python +# selects the active mixer each step (no re-capture needed). +# See Mode 1 vs Mode 2/3 in the table above. +# Default: "1" (enabled). APRIEL2_FULL_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_FULL_CUDA_GRAPHS", "1") == "1" + +# Only relevant when APRIEL2_FULL_CUDA_GRAPHS=0 (PIECEWISE mode). +# Caches a separate small CUDA graph per (mixer_name, batch_size) at each +# stochastic dispatch point. Creates ~5k graphs (~2.4 GiB), adding memory +# pressure but slightly faster than fully-eager dispatch. +# See Mode 2 vs Mode 3 in the table above. +# Default: "0" (disabled — the throughput gain is marginal). APRIEL2_MIXER_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_MIXER_CUDA_GRAPHS", "0") == "1" # Offload inactive mixer weights to CPU after profile_run(). Frees ~19 GiB # GPU memory for KV cache. On layout switch, weights are swapped layer by layer. -# Note: cannot offload during load_weights() because torch.compile captures all -# parameters as graph inputs — profile_run() would crash with CPU tensors. -# Default: enabled with FULL CUDA graphs (layout is fixed per graph capture). +# +# Constraints: +# - Cannot offload during load_weights(): torch.compile captures all parameters +# as graph inputs — profile_run() would crash with CPU tensors. +# - Only moves parameters, NOT buffers: RotaryEmbedding.cos_sin_cache is shared +# across mixer instances (via get_rope() LRU cache); moving it corrupts the +# active mixer. +# - Offloaded modules are removed from nn.ModuleDict and stored in a plain dict +# (_offloaded_mixers) to prevent torch.compile guard invalidation. +# +# Default: "1" when FULL graphs enabled (layout fixed per capture), "0" otherwise. APRIEL2_OFFLOAD_INACTIVE_MIXERS: bool = ( os.environ.get("APRIEL2_OFFLOAD_INACTIVE_MIXERS", "1" if APRIEL2_FULL_CUDA_GRAPHS else "0") == "1" ) From 9a6562db51fa5cdc8824637fe6341e0d02714765 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Mar 2026 21:15:16 +0000 Subject: [PATCH 06/13] Fix GDN/KDA crash when CUDA graph capture batch exceeds KV cache blocks During FULL CUDA graph capture, vLLM creates dummy runs with batch sizes up to max_cudagraph_capture_size (e.g. 512). When gpu_memory_utilization is low, the number of allocated KV cache blocks can be smaller than this capture batch size (e.g. 282 blocks for 0.05 utilization). This caused an assertion failure in causal_conv1d_update: assert num_cache_lines >= batch The root cause: GDN and KDA decode paths used num_actual_tokens (which equals the capture batch size during FULL graph capture) to slice conv_state_indices and cu_seqlens, but the conv/recurrent state tensors only have num_cache_blocks rows. Fix: clamp num_actual_tokens to conv_state.size(0) at the top of both GDN._forward_core() and KDA._forward(). This ensures all downstream kernel calls (causal_conv1d_update, fused_recurrent_gated_delta_rule, fused_recurrent_kda) receive batch sizes that fit within the allocated cache. During normal inference the scheduler guarantees num_actual_tokens <= cache blocks, so the clamp is a no-op. Co-Authored-By: Claude Opus 4.6 --- .../apriel2/vllm/modeling_apriel2.py | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 4dec49b9d..b328363d6 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -1830,8 +1830,14 @@ def _forward_core( conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - # self._debug_tensor("conv_state (from cache)", conv_state) - # self._debug_tensor("ssm_state (from cache)", ssm_state) + # During FULL CUDA graph capture the dummy batch size can exceed + # the number of allocated KV cache blocks (e.g. capture=512 but + # only 282 cache lines). Clamp to avoid out-of-bounds access in + # conv1d and recurrent kernels. At inference time the scheduler + # guarantees num_actual_tokens <= cache lines, so this is a no-op. + num_cache_lines = conv_state.size(0) + if num_actual_tokens > num_cache_lines: + num_actual_tokens = num_cache_lines mixed_qkv = mixed_qkv[:num_actual_tokens] b = b[:num_actual_tokens] @@ -1864,7 +1870,7 @@ def _forward_core( ).transpose(0, 1) else: # self._debug_print("Using causal_conv1d_update (decode path)") - mixed_qkv = causal_conv1d_update( + causal_conv1d_update( mixed_qkv, conv_state, conv_weights, @@ -1950,6 +1956,8 @@ def _forward_core( print( f"[vLLM-GDN {self.prefix}] DECODE inputs: q={query.flatten()[:4].tolist()}, k={key.flatten()[:4].tolist()}, v={value.flatten()[:4].tolist()}, g={g.flatten()[:4].tolist()}, beta={beta.flatten()[:4].tolist()}" ) + # num_actual_tokens already clamped to cache lines above + num_decodes = min(attn_metadata.num_decodes, num_actual_tokens) core_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -1958,8 +1966,8 @@ def _forward_core( beta=beta, initial_state=ssm_state, inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, + cu_seqlens=non_spec_query_start_loc[: num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], use_qk_l2norm_in_kernel=True, ) # self._debug_tensor("core_out (from fused_recurrent)", core_out) @@ -2260,6 +2268,18 @@ def _forward( num_actual_tokens = attn_metadata.num_actual_tokens constant_caches = self.kv_cache[forward_context.virtual_engine] + (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches + conv_state_q = conv_state_q.transpose(-1, -2) + conv_state_k = conv_state_k.transpose(-1, -2) + conv_state_v = conv_state_v.transpose(-1, -2) + + # During FULL CUDA graph capture the dummy batch size can exceed + # the number of allocated KV cache blocks. Clamp to avoid + # out-of-bounds access. At inference time this is a no-op. + num_cache_lines = conv_state_q.size(0) + if num_actual_tokens > num_cache_lines: + num_actual_tokens = num_cache_lines + if DEBUG_STATE_INDICES and not cudagraph_capturing_enabled: indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() unique = {i for i in indices if i >= 0} @@ -2277,11 +2297,6 @@ def _forward( g1 = g1[:num_actual_tokens] beta = beta[:num_actual_tokens] - (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches - conv_state_q = conv_state_q.transpose(-1, -2) - conv_state_k = conv_state_k.transpose(-1, -2) - conv_state_v = conv_state_v.transpose(-1, -2) - q_conv_weights = self.q_conv1d.weight.view(self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)) k_conv_weights = self.k_conv1d.weight.view(self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)) v_conv_weights = self.v_conv1d.weight.view(self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)) @@ -2372,6 +2387,7 @@ def _forward( ) recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state else: + num_decodes = min(attn_metadata.num_decodes, num_actual_tokens) core_attn_out_non_spec, _ = fused_recurrent_kda( q=q, k=k, @@ -2380,8 +2396,8 @@ def _forward( beta=beta, initial_state=recurrent_state, use_qk_l2norm_in_kernel=True, - cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, + cu_seqlens=non_spec_query_start_loc[: num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], ) core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[0, :num_actual_tokens] From 7a948a4c6095ce7e2189f22c60a710d504bc5bde Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 12 Mar 2026 15:38:11 +0000 Subject: [PATCH 07/13] Add unconditional CUDA sync debug points + KV cache grouping fix for hybrid models Three categories of changes for debugging and fixing throughput regression in hybrid Apriel2 models with singleton mixer types (e.g. a12_g1_k11): 1. KV cache grouping monkey-patch: use max(num_layers_per_type) as group_size when min <= 2, reducing O(num_layers) groups to O(num_types) groups. 2. Always register apriel2_gdn_attention_core as PIECEWISE splitting op, fix fused_gdn_gating_kernel total_elements constexpr, and only clamp num_actual_tokens during CUDA graph capture (not during prefill). 3. Unconditional CUDA sync points (SYNC-1 through SYNC-6) with print(flush=True) to pinpoint illegal memory access during large re-prefill after preemption. Co-Authored-By: Claude Opus 4.6 --- .../apriel2/vllm/modeling_apriel2.py | 174 +++++++++++++++--- 1 file changed, 151 insertions(+), 23 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index b328363d6..1674ca954 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -74,6 +74,73 @@ else: from vllm.attention.layer import Attention + +# --------------------------------------------------------------------------- +# Monkey-patch: fix vLLM's KV cache group_size for heterogeneous hybrid models +# --------------------------------------------------------------------------- +# vLLM's _get_kv_cache_groups_uniform_page_size computes +# group_size = min(num_layers_per_type) +# which was designed for models with simple n:1 repeating patterns (Gemma3, +# LLaMA4). For Apriel2 supernets with arbitrary placement ratios (e.g. 12 +# attention + 1 GDN + 11 KDA), min = 1 creates one KV-cache group per layer +# (24 total). Each group triggers a metadata rebuild per forward step, +# causing >2× throughput regression on small models. +# +# The fix: use max(num_layers_per_type) as group_size. This gives at most +# num_types groups (one per spec type). Tensor positions without a layer of +# a given type are simply unused by that type — no memory is wasted because +# the physical tensor at each position is shared across types. +def _patch_kv_cache_grouping() -> None: + import vllm.v1.core.kv_cache_utils as _kcu + + _original = _kcu._get_kv_cache_groups_uniform_page_size + + def _patched(kv_cache_spec: dict) -> list: + from collections import defaultdict + from math import ceil as _ceil + + same_type_layers: defaultdict[object, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec].append(layer_name) + + min_num = min(len(v) for v in same_type_layers.values()) + max_num = max(len(v) for v in same_type_layers.values()) + + if min_num > 2 or max_num < min_num * 1.25: + # Not a singleton/near-singleton case — use original logic. + return _original(kv_cache_spec) + + # Singleton / near-singleton type detected. + # Use max_num as group_size to produce at most num_types groups. + group_size = max_num + apriel2_logger.info( + "KV cache grouping: using group_size=%d (max) instead of %d " + "(min) to avoid O(num_layers) groups with %d spec types", + group_size, + min_num, + len(same_type_layers), + ) + + grouped_layers = [] + for layers in same_type_layers.values(): + num_padding = group_size - len(layers) % group_size + if num_padding != group_size: + apriel2_logger.info( + "KV cache grouping: %d padding layers for type with " "%d real layers (group_size=%d)", + num_padding, + len(layers), + group_size, + ) + num_groups = _ceil(len(layers) / group_size) + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) + return _kcu.create_kv_cache_group_specs(kv_cache_spec, grouped_layers) + + _kcu._get_kv_cache_groups_uniform_page_size = _patched + + +_patch_kv_cache_grouping() + # Lazy triton allocator setup - only called when a triton kernel needs scratch memory _triton_allocator_installed = False @@ -115,6 +182,9 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i # use the GDN metadata builder logging in gdn_attn.py instead (which also reads # this env var). DEBUG_STATE_INDICES = os.environ.get("APRIEL2_DEBUG_STATE_INDICES", "0") == "1" +# Sync CUDA before/between GDN kernels to catch the exact source of async +# illegal-memory-access errors. Very slow — only for debugging. +DEBUG_SYNC = os.environ.get("APRIEL2_DEBUG_SYNC", "0") == "1" # ============================================================================= # CUDA Graph Modes for Stochastic Mixers @@ -1018,7 +1088,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - """Return cache spec for attention with unified page size for hybrid models.""" + """Return cache spec for attention with unified page size for hybrid models. + + Returns SlidingWindowSpec for sliding-window layers and FullAttentionSpec + for regular attention. This puts them in separate KV cache groups (and + therefore separate FlashInfer metadata builders), which avoids the + "Window left is not the same for all layers" error. The monkey-patched + grouping function handles the potential singleton-group degeneration. + """ config = vllm_config.model_config.hf_config block_size, _ = get_unified_page_size_for_config(config, vllm_config) @@ -1039,13 +1116,13 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: dtype=kv_cache_dtype, sliding_window=self.window_size, ) - else: - return FullAttentionSpec( - block_size=block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_dim, - dtype=kv_cache_dtype, - ) + + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + dtype=kv_cache_dtype, + ) class Apriel2MambaMixer(nn.Module): @@ -1279,19 +1356,28 @@ def stochastic_mixer_dispatch_fake( fake_impl=stochastic_mixer_dispatch_fake, ) -# Add stochastic_mixer_dispatch to vLLM's splitting ops so it causes graph breaks. -# Only needed in PIECEWISE mode — in FULL mode the dispatch is transparent to -# graph capture (the active mixer's kernels get baked into the full graph). -if not APRIEL2_FULL_CUDA_GRAPHS: - try: - from vllm.config.compilation import CompilationConfig, CUDAGraphMode - +# Register custom ops as PIECEWISE splitting ops so they cause graph breaks. +# The Apriel2 GDN op MUST always be registered: even in FULL_AND_PIECEWISE mode, +# PIECEWISE graphs handle prefill. Without the graph break, the decode path gets +# baked into a compiled piece; prefill batches then replay the wrong path → +# illegal memory access. The stochastic dispatch op is only needed when +# FULL graphs are disabled (in FULL mode, the active mixer's kernels get baked +# into the full graph transparently). +try: + from vllm.config.compilation import CompilationConfig, CUDAGraphMode + + _gdn_op = "vllm::apriel2_gdn_attention_core" + if _gdn_op not in CompilationConfig._attention_ops: + CompilationConfig._attention_ops.append(_gdn_op) + logger.info(f"Added {_gdn_op} to vLLM splitting ops") + + if not APRIEL2_FULL_CUDA_GRAPHS: _stochastic_op = "vllm::stochastic_mixer_dispatch" if _stochastic_op not in CompilationConfig._attention_ops: CompilationConfig._attention_ops.append(_stochastic_op) logger.info(f"Added {_stochastic_op} to vLLM splitting ops") - except ImportError: - logger.warning("Could not add stochastic_mixer_dispatch to vLLM splitting ops") +except ImportError: + logger.warning("Could not add custom ops to vLLM splitting ops") def _force_piecewise_cudagraph_for_stochastic_mixers(): @@ -1341,7 +1427,7 @@ def fused_gdn_gating_kernel( g_ptr, beta_ptr, num_heads: tl.constexpr, - total_elements: tl.constexpr, + total_elements, BLOCK_SIZE: tl.constexpr, SOFTPLUS_THRESHOLD: tl.constexpr, ): @@ -1720,6 +1806,10 @@ def forward( device=hidden_states.device, ) + if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + torch.cuda.synchronize() + print(f"[SYNC-1] {self.prefix}: pre-custom-op sync OK", flush=True) + torch.ops.vllm.apriel2_gdn_attention_core( mixed_qkv, b, @@ -1833,10 +1923,16 @@ def _forward_core( # During FULL CUDA graph capture the dummy batch size can exceed # the number of allocated KV cache blocks (e.g. capture=512 but # only 282 cache lines). Clamp to avoid out-of-bounds access in - # conv1d and recurrent kernels. At inference time the scheduler - # guarantees num_actual_tokens <= cache lines, so this is a no-op. + # conv1d and recurrent kernels. + # IMPORTANT: Only clamp during CUDA graph capture. During normal + # inference, num_actual_tokens is the number of *tokens* in the + # batch, which can exceed num_blocks (= number of *state slots*) + # during prefill (many tokens share the same state slot). + # Clamping during prefill would truncate the input and cause the + # conv1d/recurrent kernels to read out-of-bounds via cache_indices + # that still reference the full batch. num_cache_lines = conv_state.size(0) - if num_actual_tokens > num_cache_lines: + if cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: num_actual_tokens = num_cache_lines mixed_qkv = mixed_qkv[:num_actual_tokens] @@ -1857,6 +1953,7 @@ def _forward_core( # self._debug_print("Using causal_conv1d_fn (prefill path)") mixed_qkv_T = mixed_qkv.transpose(0, 1) # self._debug_tensor("mixed_qkv_T (before conv)", mixed_qkv_T) + mixed_qkv = causal_conv1d_fn( mixed_qkv_T, conv_weights, @@ -1868,6 +1965,9 @@ def _forward_core( query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ).transpose(0, 1) + if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + torch.cuda.synchronize() + print(f"[SYNC-2] {self.prefix}: post-causal_conv1d_fn sync OK", flush=True) else: # self._debug_print("Using causal_conv1d_update (decode path)") causal_conv1d_update( @@ -1899,13 +1999,33 @@ def _forward_core( # self._debug_tensor("A_log", self.A_log) # self._debug_tensor("dt_bias", self.dt_bias) + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + torch.cuda.synchronize() + print(f"[SYNC-3] {self.prefix}: post-fused_gdn_gating sync OK", flush=True) # self._debug_tensor("g (from gating)", g) # self._debug_tensor("beta (from gating)", beta) # Recurrent attention if attn_metadata.num_prefills > 0: # self._debug_print("Using chunk_gated_delta_rule (prefill)") + # Bounds check + diagnostics (sync-3 already caught async errors) + if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + max_idx = non_spec_state_indices_tensor.max().item() + print( + f"[SYNC-4] {self.prefix}: prefill ssm_state.shape={ssm_state.shape}, " + f"indices={non_spec_state_indices_tensor.tolist()}, " + f"max_idx={max_idx}, num_cache_lines={num_cache_lines}, " + f"has_initial_state={has_initial_state.tolist()}, " + f"num_prefills={attn_metadata.num_prefills}", + flush=True, + ) + if max_idx >= ssm_state.shape[0]: + print( + f"[OOB] {self.prefix}: max_idx={max_idx} >= ssm_state.shape[0]={ssm_state.shape[0]}!", + flush=True, + ) initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 # self._debug_tensor("initial_state", initial_state) @@ -1935,6 +2055,9 @@ def _forward_core( head_first=False, use_qk_l2norm_in_kernel=True, ) + if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + torch.cuda.synchronize() + print(f"[SYNC-5] {self.prefix}: post-chunk_gated_delta_rule sync OK", flush=True) # self._debug_tensor("core_out (from chunk_gated_delta_rule)", core_out) # self._debug_tensor("last_state", last_state) # # Debug prefill state - get seq_len from query_start_loc @@ -2231,6 +2354,10 @@ def forward( dtype=hidden_states.dtype, device=hidden_states.device, ) + if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + torch.cuda.synchronize() + print(f"[SYNC-6] {self.prefix}: pre-kda_attention sync OK", flush=True) + torch.ops.vllm.kda_attention( q, k, @@ -2275,9 +2402,10 @@ def _forward( # During FULL CUDA graph capture the dummy batch size can exceed # the number of allocated KV cache blocks. Clamp to avoid - # out-of-bounds access. At inference time this is a no-op. + # out-of-bounds access. + # IMPORTANT: Only clamp during CUDA graph capture — see GDN comment. num_cache_lines = conv_state_q.size(0) - if num_actual_tokens > num_cache_lines: + if cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: num_actual_tokens = num_cache_lines if DEBUG_STATE_INDICES and not cudagraph_capturing_enabled: From 60541b0a00582fd109ee2bb345aee030d989dad1 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 12 Mar 2026 17:48:59 +0000 Subject: [PATCH 08/13] Fix cudagraph_capturing_enabled import-by-value bug causing OOB crash on re-prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `from vllm.compilation.monitor import cudagraph_capturing_enabled` copies the default value True at import time. When vLLM later calls set_cudagraph_capturing_enabled(False), the module global changes but the imported name stays True. This caused num_actual_tokens to be wrongly clamped to num_cache_lines during inference (not just CUDA graph capture), truncating mixed_qkv while query_start_loc still referenced the full 16384-token batch — making causal_conv1d_fn read OOB and crash. Fix: use _compile_monitor.cudagraph_capturing_enabled (module attribute access) for all runtime checks so the current value is always read. Debug sync points (SYNC-1..6) kept behind APRIEL2_DEBUG_SYNC=1 flag. Co-Authored-By: Claude Opus 4.6 --- .../apriel2/vllm/modeling_apriel2.py | 47 ++++++++++++++----- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 1674ca954..39d4ae559 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -17,6 +17,7 @@ import torch import triton +import vllm.compilation.monitor as _compile_monitor import vllm.model_executor.layers.kda # noqa: F401 from einops import rearrange from torch import nn @@ -24,7 +25,7 @@ from transformers.activations import ACT2FN from vllm import __version_tuple__ as _vllm_version from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.monitor import cudagraph_capturing_enabled, validate_cudagraph_capturing_enabled +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id @@ -1319,7 +1320,11 @@ def stochastic_mixer_dispatch( from vllm.config.compilation import CUDAGraphMode runtime_mode = forward_context.cudagraph_runtime_mode - if cudagraph_capturing_enabled and runtime_mode is not None and runtime_mode != CUDAGraphMode.NONE: + if ( + _compile_monitor.cudagraph_capturing_enabled + and runtime_mode is not None + and runtime_mode != CUDAGraphMode.NONE + ): _capture_all_mixers_for_num_tokens(stochastic_mixer, cache, hidden_states, output, positions, num_tokens) return @@ -1806,7 +1811,7 @@ def forward( device=hidden_states.device, ) - if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: torch.cuda.synchronize() print(f"[SYNC-1] {self.prefix}: pre-custom-op sync OK", flush=True) @@ -1901,7 +1906,7 @@ def _forward_core( non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor num_actual_tokens = attn_metadata.num_actual_tokens - if DEBUG_STATE_INDICES and not cudagraph_capturing_enabled: + if DEBUG_STATE_INDICES and not _compile_monitor.cudagraph_capturing_enabled: indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() unique = {i for i in indices if i >= 0} dup = len(indices) - len(unique) - indices.count(-1) @@ -1932,7 +1937,7 @@ def _forward_core( # conv1d/recurrent kernels to read out-of-bounds via cache_indices # that still reference the full batch. num_cache_lines = conv_state.size(0) - if cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: + if _compile_monitor.cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: num_actual_tokens = num_cache_lines mixed_qkv = mixed_qkv[:num_actual_tokens] @@ -1954,6 +1959,23 @@ def _forward_core( mixed_qkv_T = mixed_qkv.transpose(0, 1) # self._debug_tensor("mixed_qkv_T (before conv)", mixed_qkv_T) + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: + max_ci = non_spec_state_indices_tensor.max().item() + print( + f"[PRE-CONV] {self.prefix}: conv_state.shape={conv_state.shape}, " + f"x.shape={mixed_qkv_T.shape}, " + f"cache_indices={non_spec_state_indices_tensor.tolist()} (max={max_ci}), " + f"num_cache_lines={num_cache_lines}, " + f"query_start_loc={non_spec_query_start_loc.tolist()}, " + f"has_initial_state={has_initial_state.tolist()}, " + f"num_prefills={attn_metadata.num_prefills}", + flush=True, + ) + if max_ci >= num_cache_lines: + print( + f"[OOB-CONV] {self.prefix}: cache_index {max_ci} >= num_cache_lines {num_cache_lines}!", + flush=True, + ) mixed_qkv = causal_conv1d_fn( mixed_qkv_T, conv_weights, @@ -1965,7 +1987,7 @@ def _forward_core( query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ).transpose(0, 1) - if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: torch.cuda.synchronize() print(f"[SYNC-2] {self.prefix}: post-causal_conv1d_fn sync OK", flush=True) else: @@ -2001,7 +2023,7 @@ def _forward_core( # self._debug_tensor("dt_bias", self.dt_bias) g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: torch.cuda.synchronize() print(f"[SYNC-3] {self.prefix}: post-fused_gdn_gating sync OK", flush=True) # self._debug_tensor("g (from gating)", g) @@ -2010,8 +2032,7 @@ def _forward_core( # Recurrent attention if attn_metadata.num_prefills > 0: # self._debug_print("Using chunk_gated_delta_rule (prefill)") - # Bounds check + diagnostics (sync-3 already caught async errors) - if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: max_idx = non_spec_state_indices_tensor.max().item() print( f"[SYNC-4] {self.prefix}: prefill ssm_state.shape={ssm_state.shape}, " @@ -2055,7 +2076,7 @@ def _forward_core( head_first=False, use_qk_l2norm_in_kernel=True, ) - if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: torch.cuda.synchronize() print(f"[SYNC-5] {self.prefix}: post-chunk_gated_delta_rule sync OK", flush=True) # self._debug_tensor("core_out (from chunk_gated_delta_rule)", core_out) @@ -2354,7 +2375,7 @@ def forward( dtype=hidden_states.dtype, device=hidden_states.device, ) - if not torch.compiler.is_compiling() and not cudagraph_capturing_enabled: + if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: torch.cuda.synchronize() print(f"[SYNC-6] {self.prefix}: pre-kda_attention sync OK", flush=True) @@ -2405,10 +2426,10 @@ def _forward( # out-of-bounds access. # IMPORTANT: Only clamp during CUDA graph capture — see GDN comment. num_cache_lines = conv_state_q.size(0) - if cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: + if _compile_monitor.cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: num_actual_tokens = num_cache_lines - if DEBUG_STATE_INDICES and not cudagraph_capturing_enabled: + if DEBUG_STATE_INDICES and not _compile_monitor.cudagraph_capturing_enabled: indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() unique = {i for i in indices if i >= 0} dup = len(indices) - len(unique) - indices.count(-1) From be801717039bc4d834a891eea8ed27cdc2b0805e Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 9 Apr 2026 19:58:29 +0000 Subject: [PATCH 09/13] Clean up modeling_apriel2: remove WIP CUDA graph infrastructure, preserve bug fixes Remove code added by the 5 intermediate WIP commits (per-mixer graph caching, CPU offload for inactive mixers, FULL CUDA graph investigation, and associated documentation): - APRIEL2_FULL_CUDA_GRAPHS / APRIEL2_MIXER_CUDA_GRAPHS / APRIEL2_OFFLOAD_INACTIVE_MIXERS env vars - MixerGraphCache class and _capture_all_mixers_for_num_tokens - _move_module_device and CPU offload logic in StochasticMixer - cudagraph_capturing_enabled check in stochastic_mixer_dispatch - Extensive CUDA graph mode benchmark documentation (~300 lines) Preserve all changes from the 3 bug-fix commits (9a6562db, 7a948a4c, 60541b0a): - _patch_kv_cache_grouping() for heterogeneous hybrid model KV cache grouping - import vllm.compilation.monitor as _compile_monitor (fix import-by-value bug) - GDN/KDA clamp (cudagraph_capturing_enabled-conditional) to prevent OOB during capture - causal_conv1d_update in-place fix (don't assign return value) - num_decodes clamped to num_actual_tokens + ssm_state_indices[:num_actual_tokens] - apriel2_gdn_attention_core always registered as PIECEWISE splitting op - Attention.get_kv_cache_spec: remove spurious else branch - fused_gdn_gating_kernel total_elements: remove tl.constexpr - DEBUG_SYNC env var + SYNC-1..6 cuda sync debug points (gated behind APRIEL2_DEBUG_SYNC=1) Co-Authored-By: Claude Sonnet 4.6 --- .../apriel2/vllm/modeling_apriel2.py | 750 ++---------------- 1 file changed, 47 insertions(+), 703 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 39d4ae559..3dd822603 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -13,7 +13,7 @@ from collections.abc import Iterable from dataclasses import dataclass from itertools import islice -from typing import Callable, Literal +from typing import Literal import torch import triton @@ -23,12 +23,10 @@ from torch import nn from transformers import PretrainedConfig from transformers.activations import ACT2FN -from vllm import __version_tuple__ as _vllm_version +from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size -from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul @@ -64,16 +62,33 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import current_stream, direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.selector import get_mamba_attn_backend from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec -if _vllm_version >= (0, 16, 0): - from vllm.model_executor.layers.attention import Attention -else: - from vllm.attention.layer import Attention +# Lazy triton allocator setup - only called when a triton kernel needs scratch memory +_triton_allocator_installed = False + + +def _install_triton_allocator() -> None: + """Install triton allocator lazily to avoid early CUDA initialization.""" + global _triton_allocator_installed + if _triton_allocator_installed: + return + + def _triton_allocator(size: int, alignment: int, stream: int | None): # type: ignore[return] + return torch.empty(size, dtype=torch.int8, device="cuda").data_ptr() + + triton.set_allocator(_triton_allocator) + _triton_allocator_installed = True + + +logger = logging.getLogger(__name__) + + +apriel2_logger = init_logger(__name__) # --------------------------------------------------------------------------- @@ -127,7 +142,7 @@ def _patched(kv_cache_spec: dict) -> list: num_padding = group_size - len(layers) % group_size if num_padding != group_size: apriel2_logger.info( - "KV cache grouping: %d padding layers for type with " "%d real layers (group_size=%d)", + "KV cache grouping: %d padding layers for type with %d real layers (group_size=%d)", num_padding, len(layers), group_size, @@ -142,28 +157,6 @@ def _patched(kv_cache_spec: dict) -> list: _patch_kv_cache_grouping() -# Lazy triton allocator setup - only called when a triton kernel needs scratch memory -_triton_allocator_installed = False - - -def _install_triton_allocator() -> None: - """Install triton allocator lazily to avoid early CUDA initialization.""" - global _triton_allocator_installed - if _triton_allocator_installed: - return - - def _triton_allocator(size: int, alignment: int, stream: int | None): # type: ignore[return] - return torch.empty(size, dtype=torch.int8, device="cuda").data_ptr() - - triton.set_allocator(_triton_allocator) - _triton_allocator_installed = True - - -logger = logging.getLogger(__name__) - - -apriel2_logger = init_logger(__name__) - # ============================================================================= # Debug Flags # ============================================================================= @@ -177,273 +170,10 @@ def _triton_allocator(size: int, alignment: int, stream: int | None): # type: i DEBUG_DECODER_LAYER = False # Debug decoder layer outputs (residual, norm) DEBUG_FINAL_NORM = False # Debug final norm before LM head DEBUG_LM_HEAD = False # Debug LM head input/output -# Log ssm_state_indices at every GDN/KDA forward (decode only, suppressed -# during CUDA graph capture). Flags duplicate block IDs with "** DUPLICATE **". -# NOTE: In FULL CUDA graph mode, model-side logging never executes during replay; -# use the GDN metadata builder logging in gdn_attn.py instead (which also reads -# this env var). -DEBUG_STATE_INDICES = os.environ.get("APRIEL2_DEBUG_STATE_INDICES", "0") == "1" # Sync CUDA before/between GDN kernels to catch the exact source of async # illegal-memory-access errors. Very slow — only for debugging. DEBUG_SYNC = os.environ.get("APRIEL2_DEBUG_SYNC", "0") == "1" -# ============================================================================= -# CUDA Graph Modes for Stochastic Mixers -# ============================================================================= -# -# The Apriel2StochasticMixer wraps multiple sub-mixers (attention, GDN, KDA) -# per layer and routes to the active one at runtime. This interacts with -# vLLM's CUDA graph capture in several modes, controlled by env vars. -# -# Benchmark setup: 10 concurrent requests, all-attention layout, prompt -# length 1, max generation length 16k, REST backend, no warmup after -# local vLLM launch, single H100 80GB. -# -# ┌──────────────────────────────────────────────────────────────────────┐ -# │ Mode 0: Fixed-layout FULL graphs (baseline, no supernet) │ -# ├──────────────────────────────────────────────────────────────────────┤ -# │ Serve a checkpoint with a predefined layout (single mixer per │ -# │ layer). Standard vLLM FULL CUDA graph capture — no stochastic │ -# │ dispatch involved. │ -# │ │ -# │ Weights: 26.91 GiB KV cache: 43.50 GiB Graphs: 0.10 GiB │ -# │ Throughput: 583 tok/s │ -# │ │ -# │ This is the upper bound — all mixer weights are collapsed into one │ -# │ per layer, leaving maximum memory for KV cache. │ -# └──────────────────────────────────────────────────────────────────────┘ -# -# ┌──────────────────────────────────────────────────────────────────────┐ -# │ Mode 1: Supernet + FULL graphs + weight offload │ -# │ (APRIEL2_FULL_CUDA_GRAPHS=1, APRIEL2_OFFLOAD_INACTIVE_MIXERS=1)│ -# │ [default] │ -# ├──────────────────────────────────────────────────────────────────────┤ -# │ vLLM captures the entire forward as one CUDA graph per batch size. │ -# │ stochastic_mixer_dispatch is NOT a graph-splitting op: during │ -# │ capture it calls the active mixer, baking its GPU kernels into the │ -# │ graph. On replay the same kernels execute. │ -# │ │ -# │ After profile_run(), inactive mixer weights are offloaded to CPU. │ -# │ This reduces GPU weight memory to ~26.9 GiB (matching Mode 0) and │ -# │ frees ~19 GiB for KV cache. Only parameters are moved — shared │ -# │ buffers (e.g. RotaryEmbedding cos_sin_cache) stay on GPU. │ -# │ Offloaded mixers are also removed from nn.ModuleDict to avoid │ -# │ torch.compile guard invalidation. │ -# │ │ -# │ On layout change (set_layer_placements): │ -# │ 1. KV cache is cleared │ -# │ 2. Weights are swapped layer by layer (old→CPU, new→GPU) │ -# │ 3. All captured CUDA graphs are invalidated and re-captured │ -# │ via capture_model() (~5-15 s) │ -# │ │ -# │ Weights: 26.91 GiB KV cache: 43.50 GiB Graphs: 0.07 GiB │ -# │ Throughput: 516 tok/s │ -# └──────────────────────────────────────────────────────────────────────┘ -# -# ┌──────────────────────────────────────────────────────────────────────┐ -# │ Mode 1b: Supernet + FULL graphs, no offload │ -# │ (APRIEL2_FULL_CUDA_GRAPHS=1, APRIEL2_OFFLOAD_INACTIVE_MIXERS=0)│ -# ├──────────────────────────────────────────────────────────────────────┤ -# │ Same as Mode 1 but all mixer weights stay on GPU. This wastes │ -# │ ~19 GiB on inactive mixers, leaving 1.78× less KV cache. │ -# │ │ -# │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: 0.09 GiB │ -# │ Throughput: ~200 tok/s │ -# └──────────────────────────────────────────────────────────────────────┘ -# -# ┌──────────────────────────────────────────────────────────────────────┐ -# │ Mode 2: Supernet + PIECEWISE + per-mixer sub-graphs │ -# │ (APRIEL2_FULL_CUDA_GRAPHS=0, APRIEL2_MIXER_CUDA_GRAPHS=1) │ -# ├──────────────────────────────────────────────────────────────────────┤ -# │ stochastic_mixer_dispatch is a graph-splitting op → vLLM forces │ -# │ PIECEWISE mode. At each dispatch point, a separate small CUDA │ -# │ graph is cached per (mixer, batch_size) and replayed. │ -# │ │ -# │ No re-capture needed on layout change (dispatch selects mixer at │ -# │ runtime), but creates ~5k graphs causing GPU memory pressure. │ -# │ │ -# │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: 2.43 GiB │ -# │ Capture time: ~45 s │ -# │ Throughput: ~155 tok/s │ -# │ │ -# │ WARNING: Not recommended. The graph memory further reduces │ -# │ available KV cache and capture overhead is substantial. │ -# └──────────────────────────────────────────────────────────────────────┘ -# -# ┌──────────────────────────────────────────────────────────────────────┐ -# │ Mode 3: Supernet + PIECEWISE + eager dispatch │ -# │ (APRIEL2_FULL_CUDA_GRAPHS=0, APRIEL2_MIXER_CUDA_GRAPHS=0) │ -# ├──────────────────────────────────────────────────────────────────────┤ -# │ Same as Mode 2 but mixer forward runs fully eagerly (no per-mixer │ -# │ graph caching). Graph breaks at every dispatch, Python selects the │ -# │ active mixer each step. │ -# │ │ -# │ No re-capture needed on layout change. │ -# │ │ -# │ Weights: 45.92 GiB KV cache: 24.48 GiB Graphs: ~0 GiB │ -# │ Throughput: 151 tok/s │ -# └──────────────────────────────────────────────────────────────────────┘ -# -# Summary (H100 80 GB, all-attention layout, 10 concurrent reqs, 16k output, prompt length 1): -# -# Mode │ Supernet │ FULL │ Offload │ Per-mixer subgraph │ Weights │ KV cache │ Graphs │ tok/s -# ─────┼──────────┼──────┼─────────┼────────────────────┼─────────┼──────────┼────────┼────── -# 0 │ no │ on │ - │ - │ 26.9 Gi │ 43.5 Gi │ 0.1 Gi │ 583 -# 1 │ yes │ on │ on │ - │ 26.9 Gi │ 43.5 Gi │ 0.1 Gi │ 516 -# 1b │ yes │ on │ off │ - │ 45.9 Gi │ 24.5 Gi │ 0.1 Gi │ ~200 -# 2 │ yes │ off │ off │ on │ 45.9 Gi │ 24.5 Gi │ 2.4 Gi │ ~155 -# 3 │ yes │ off │ off │ off │ 45.9 Gi │ 24.5 Gi │ ~0 Gi │ ~151 -# -# FULL = APRIEL2_FULL_CUDA_GRAPHS -# Offload = APRIEL2_OFFLOAD_INACTIVE_MIXERS -# Per-mixer subgraph = APRIEL2_MIXER_CUDA_GRAPHS -# -# Note: CUDA graph capture is essential for linear mixers (GDN, KDA). -# They will be slower than attention in Mode 3 (eager dispatch) but -# can be faster than attention in Mode 1 (FULL graphs). -# -# ============================================================================= -# Capture the entire forward (including stochastic dispatch) as one monolithic -# CUDA graph per batch size. On placement change, ALL graphs are invalidated -# and re-captured via capture_model() (~5-15 s). When disabled, the dispatch -# op is registered as a graph-splitting op, forcing PIECEWISE mode where Python -# selects the active mixer each step (no re-capture needed). -# See Mode 1 vs Mode 2/3 in the table above. -# Default: "1" (enabled). -APRIEL2_FULL_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_FULL_CUDA_GRAPHS", "1") == "1" - -# Only relevant when APRIEL2_FULL_CUDA_GRAPHS=0 (PIECEWISE mode). -# Caches a separate small CUDA graph per (mixer_name, batch_size) at each -# stochastic dispatch point. Creates ~5k graphs (~2.4 GiB), adding memory -# pressure but slightly faster than fully-eager dispatch. -# See Mode 2 vs Mode 3 in the table above. -# Default: "0" (disabled — the throughput gain is marginal). -APRIEL2_MIXER_CUDA_GRAPHS: bool = os.environ.get("APRIEL2_MIXER_CUDA_GRAPHS", "0") == "1" - -# Offload inactive mixer weights to CPU after profile_run(). Frees ~19 GiB -# GPU memory for KV cache. On layout switch, weights are swapped layer by layer. -# -# Constraints: -# - Cannot offload during load_weights(): torch.compile captures all parameters -# as graph inputs — profile_run() would crash with CPU tensors. -# - Only moves parameters, NOT buffers: RotaryEmbedding.cos_sin_cache is shared -# across mixer instances (via get_rope() LRU cache); moving it corrupts the -# active mixer. -# - Offloaded modules are removed from nn.ModuleDict and stored in a plain dict -# (_offloaded_mixers) to prevent torch.compile guard invalidation. -# -# Default: "1" when FULL graphs enabled (layout fixed per capture), "0" otherwise. -APRIEL2_OFFLOAD_INACTIVE_MIXERS: bool = ( - os.environ.get("APRIEL2_OFFLOAD_INACTIVE_MIXERS", "1" if APRIEL2_FULL_CUDA_GRAPHS else "0") == "1" -) - - -def _move_module_device(module: nn.Module, device: torch.device) -> None: - """Move a module's weight parameters to a device. - - Uses param.data assignment (not module.to()) to preserve vLLM's - BasevLLMParameter metadata (_weight_loader attribute). - - Only moves parameters, NOT buffers. Buffers like RotaryEmbedding's - cos_sin_cache are shared across mixer instances (via get_rope() LRU cache). - Moving them would corrupt the active mixer's shared state. - """ - for param in module.parameters(): - param.data = param.data.to(device, non_blocking=True) - - -# ============================================================================= -# Per-Mixer CUDA Graph Cache -# ============================================================================= - - -@dataclass -class MixerGraphEntry: - """A captured CUDA graph for one mixer at one batch size.""" - - graph: torch.cuda.CUDAGraph - input_ptr: int # hidden_states.data_ptr() at capture time - output_ptr: int # output.data_ptr() at capture time - - -class MixerGraphCache: - """Per-mixer, per-batch-size CUDA graph cache for stochastic mixers. - - Each entry is keyed by (mixer_name, num_tokens). Capture happens during - vLLM's capture_model() phase (when cudagraph_capturing_enabled is True - and NCCL is in graph-safe mode). Replay happens during normal decode. - """ - - def __init__(self) -> None: - self._entries: dict[tuple[str, int], MixerGraphEntry] = {} - # Use a PRIVATE graph pool to avoid fragmenting vLLM's global pool. - # The global pool is shared with piecewise graph pieces; interleaving - # our ~5000 mixer captures with vLLM's piece captures degrades - # piecewise replay performance ~2x due to pool fragmentation. - self._graph_pool = torch.cuda.graph_pool_handle() - - def has(self, mixer_name: str, num_tokens: int) -> bool: - return (mixer_name, num_tokens) in self._entries - - def capture( - self, - mixer_name: str, - num_tokens: int, - mixer_fn: Callable, - hidden_states: torch.Tensor, - output: torch.Tensor, - positions: torch.Tensor | None, - ) -> None: - """Capture a CUDA graph for mixer_fn with the given inputs. - - Must be called during vLLM's capture_model() phase. - """ - validate_cudagraph_capturing_enabled() - key = (mixer_name, num_tokens) - if key in self._entries: - return - - graph = torch.cuda.CUDAGraph() - if self._graph_pool is not None: - set_graph_pool_id(self._graph_pool) - - with torch.cuda.graph(graph, pool=self._graph_pool, stream=current_stream()): - mixer_fn(hidden_states, output, positions=positions) - - self._entries[key] = MixerGraphEntry( - graph=graph, - input_ptr=hidden_states.data_ptr(), - output_ptr=output.data_ptr(), - ) - apriel2_logger.debug( - "MixerGraphCache: captured graph for (%s, %d), total=%d", - mixer_name, - num_tokens, - len(self._entries), - ) - - def replay( - self, - mixer_name: str, - num_tokens: int, - hidden_states: torch.Tensor, - output: torch.Tensor, - ) -> None: - """Replay the cached graph. Asserts pointer stability in debug mode.""" - entry = self._entries[(mixer_name, num_tokens)] - if __debug__: - assert hidden_states.data_ptr() == entry.input_ptr, ( - f"MixerGraphCache: hidden_states pointer changed between " - f"capture (0x{entry.input_ptr:x}) and replay " - f"(0x{hidden_states.data_ptr():x}) for ({mixer_name}, {num_tokens})" - ) - assert output.data_ptr() == entry.output_ptr, ( - f"MixerGraphCache: output pointer changed between " - f"capture (0x{entry.output_ptr:x}) and replay " - f"(0x{output.data_ptr():x}) for ({mixer_name}, {num_tokens})" - ) - entry.graph.replay() - # ============================================================================= # KV Cache Spec Computation @@ -852,28 +582,17 @@ def __init__( super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - # Attention-type mixers that use KV cache (not recurrent state) - _ATTENTION_MIXER_TYPES = {"attention", "sliding_window"} - @property def layers_block_type(self) -> list[str]: - """Return block types for each layer, normalized to 'attention' or 'mamba'. - - vLLM's get_num_layers_by_block_type() expects these two canonical types. - All attention-like mixers map to 'attention'; all recurrent mixers - (GDN, KDA, Mamba) map to 'mamba'. - """ + """Return block types for each layer (for hybrid model detection).""" decoder_config = self.decoder seq_type = decoder_config.get("type", "fixed") num_blocks = decoder_config.get("num_blocks", self.num_hidden_layers) - def _normalize(mixer_type: str) -> str: - return "attention" if mixer_type in self._ATTENTION_MIXER_TYPES else "mamba" - if seq_type == "fixed": block_config = decoder_config.get("block", {}) mixer_type = block_config.get("mixer", {}).get("type", "attention") - return [_normalize(mixer_type)] * num_blocks + return [mixer_type] * num_blocks elif seq_type == "pattern": pattern = decoder_config.get("pattern", ["attention"]) blocks_config = decoder_config.get("blocks", {}) @@ -881,7 +600,7 @@ def _normalize(mixer_type: str) -> str: for i in range(num_blocks): block_name = pattern[i % len(pattern)] mixer_type = blocks_config.get(block_name, {}).get("mixer", {}).get("type", "attention") - result.append(_normalize(mixer_type)) + result.append(mixer_type) return result return ["attention"] * num_blocks @@ -1089,14 +808,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - """Return cache spec for attention with unified page size for hybrid models. - - Returns SlidingWindowSpec for sliding-window layers and FullAttentionSpec - for regular attention. This puts them in separate KV cache groups (and - therefore separate FlashInfer metadata builders), which avoids the - "Window left is not the same for all layers" error. The monkey-patched - grouping function handles the potential singleton-group degeneration. - """ + """Return cache spec for attention with unified page size for hybrid models.""" config = vllm_config.model_config.hf_config block_size, _ = get_unified_page_size_for_config(config, vllm_config) @@ -1227,115 +939,20 @@ def apriel2_gdn_attention_core_fake( # ============================================================================= -def _batch_has_prefill(forward_context: ForwardContext, active_mixer: nn.Module) -> bool: - """Return True if this batch contains prefill tokens. - - CUDA graphs captured during decode cannot handle prefill, so we must - fall back to eager execution for mixed batches. - """ - from vllm.config.compilation import CUDAGraphMode - - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.NONE: - return True # Profile/warmup run — treat as non-graphable - - attn_meta = forward_context.attn_metadata - if isinstance(attn_meta, dict): - mixer_prefix = getattr(active_mixer, "prefix", None) - if mixer_prefix is not None: - meta = attn_meta.get(mixer_prefix) - if meta is not None and hasattr(meta, "num_prefills"): - return meta.num_prefills > 0 - return False - - -def _capture_all_mixers_for_num_tokens( - stochastic_mixer: "Apriel2StochasticMixer", - cache: MixerGraphCache, - hidden_states: torch.Tensor, - output: torch.Tensor, - positions: torch.Tensor | None, - num_tokens: int, -) -> None: - """Capture CUDA graphs for ALL sub-mixers at this batch size. - - Called during vLLM's capture_model() phase. All mixers are captured - writing into the same ``output`` buffer address so that any mixer's - graph can be replayed into the current output on future decode steps. - - After capturing, the active mixer is run eagerly once to leave the - correct result in ``output`` for downstream layers. - """ - active_name = stochastic_mixer.active_mixer_name - - for mixer_name, mixer in stochastic_mixer.mixers.items(): - if cache.has(mixer_name, num_tokens): - continue - - # Eager warmup: run the mixer once outside graph capture to trigger - # Triton autotuning (which calls cuda.synchronize() internally). - # Without this, autotuning during torch.cuda.graph() capture causes - # "operation not permitted when stream is capturing". - torch.cuda.synchronize() - output.zero_() - mixer(hidden_states, output, positions=positions) - - torch.cuda.synchronize() - output.zero_() - - apriel2_logger.info( - "Capturing mixer CUDA graph: layer=%s mixer=%s num_tokens=%d", - stochastic_mixer.prefix, - mixer_name, - num_tokens, - ) - cache.capture(mixer_name, num_tokens, mixer, hidden_states, output, positions) - - # Restore output to the active mixer's result for the outer capture pass - torch.cuda.synchronize() - stochastic_mixer.mixers[active_name](hidden_states, output, positions=positions) - - def stochastic_mixer_dispatch( hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor | None, layer_name: str, ) -> None: - """Dispatch to the active mixer; capture or replay CUDA graphs when enabled.""" + """Dispatch to the active mixer at runtime (escapes torch.compile).""" forward_context: ForwardContext = get_forward_context() stochastic_mixer = forward_context.no_compile_layers[layer_name] - active_name: str = stochastic_mixer.active_mixer_name - active_mixer = stochastic_mixer.mixers[active_name] - cache: MixerGraphCache | None = stochastic_mixer._mixer_graph_cache - - if cache is not None: - num_tokens = hidden_states.shape[0] - - # Capture phase: capture graphs for ALL mixers at this batch size. - # We must check runtime_mode to avoid capturing during profile_run, - # where cudagraph_capturing_enabled is True but cuBLAS hasn't been - # lazily initialized yet (CUBLAS_STATUS_NOT_INITIALIZED). - # During profile_run, runtime_mode is NONE; during capture_model() - # it is PIECEWISE. - from vllm.config.compilation import CUDAGraphMode - - runtime_mode = forward_context.cudagraph_runtime_mode - if ( - _compile_monitor.cudagraph_capturing_enabled - and runtime_mode is not None - and runtime_mode != CUDAGraphMode.NONE - ): - _capture_all_mixers_for_num_tokens(stochastic_mixer, cache, hidden_states, output, positions, num_tokens) - return - # Replay phase: use cached graph if available and batch is decode-only - has_prefill = _batch_has_prefill(forward_context, active_mixer) - has_cached = cache.has(active_name, num_tokens) - if not has_prefill and has_cached: - cache.replay(active_name, num_tokens, hidden_states, output) - return + # Get the currently active mixer (runtime lookup) + active_mixer = stochastic_mixer.mixers[stochastic_mixer.active_mixer_name] - # Eager fallback + # Forward through the active mixer active_mixer(hidden_states, output, positions=positions) @@ -1365,9 +982,7 @@ def stochastic_mixer_dispatch_fake( # The Apriel2 GDN op MUST always be registered: even in FULL_AND_PIECEWISE mode, # PIECEWISE graphs handle prefill. Without the graph break, the decode path gets # baked into a compiled piece; prefill batches then replay the wrong path → -# illegal memory access. The stochastic dispatch op is only needed when -# FULL graphs are disabled (in FULL mode, the active mixer's kernels get baked -# into the full graph transparently). +# illegal memory access. try: from vllm.config.compilation import CompilationConfig, CUDAGraphMode @@ -1376,11 +991,10 @@ def stochastic_mixer_dispatch_fake( CompilationConfig._attention_ops.append(_gdn_op) logger.info(f"Added {_gdn_op} to vLLM splitting ops") - if not APRIEL2_FULL_CUDA_GRAPHS: - _stochastic_op = "vllm::stochastic_mixer_dispatch" - if _stochastic_op not in CompilationConfig._attention_ops: - CompilationConfig._attention_ops.append(_stochastic_op) - logger.info(f"Added {_stochastic_op} to vLLM splitting ops") + _stochastic_op = "vllm::stochastic_mixer_dispatch" + if _stochastic_op not in CompilationConfig._attention_ops: + CompilationConfig._attention_ops.append(_stochastic_op) + logger.info(f"Added {_stochastic_op} to vLLM splitting ops") except ImportError: logger.warning("Could not add custom ops to vLLM splitting ops") @@ -1906,17 +1520,6 @@ def _forward_core( non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor num_actual_tokens = attn_metadata.num_actual_tokens - if DEBUG_STATE_INDICES and not _compile_monitor.cudagraph_capturing_enabled: - indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() - unique = {i for i in indices if i >= 0} - dup = len(indices) - len(unique) - indices.count(-1) - tag = " ** DUPLICATE **" if dup > 0 else "" - print( - f"[STATE-IDX GDN {self.prefix}] " - f"decodes={attn_metadata.num_decodes} prefills={attn_metadata.num_prefills} " - f"indices={indices}{tag}" - ) - # self._debug_print(f"num_actual_tokens={num_actual_tokens}, num_prefills={attn_metadata.num_prefills}, num_decodes={attn_metadata.num_decodes}") # self._debug_print(f"has_initial_state={has_initial_state}") # self._debug_print(f"non_spec_query_start_loc={non_spec_query_start_loc}") @@ -2081,21 +1684,9 @@ def _forward_core( print(f"[SYNC-5] {self.prefix}: post-chunk_gated_delta_rule sync OK", flush=True) # self._debug_tensor("core_out (from chunk_gated_delta_rule)", core_out) # self._debug_tensor("last_state", last_state) - # # Debug prefill state - get seq_len from query_start_loc - # if non_spec_query_start_loc is not None and len(non_spec_query_start_loc) >= 2: - # prefill_seq_len = int(non_spec_query_start_loc[1] - non_spec_query_start_loc[0]) - # else: - # prefill_seq_len = num_actual_tokens - # self._debug_state_stats("PREFILL out_state", last_state, prefill_seq_len) ssm_state[non_spec_state_indices_tensor] = last_state.to(ssm_state.dtype) else: # self._debug_print("Using fused_recurrent_gated_delta_rule (decode)") - # # For decode, access the correct slot using state indices - # if non_spec_state_indices_tensor is not None and len(non_spec_state_indices_tensor) > 0: - # slot_idx = int(non_spec_state_indices_tensor[0]) - # actual_state = ssm_state[slot_idx:slot_idx+1] - # # self._debug_state_stats("DECODE in_state", actual_state, num_actual_tokens) - # Debug decode inputs if DEBUG_GDN_STATE: print( f"[vLLM-GDN {self.prefix}] DECODE inputs: q={query.flatten()[:4].tolist()}, k={key.flatten()[:4].tolist()}, v={value.flatten()[:4].tolist()}, g={g.flatten()[:4].tolist()}, beta={beta.flatten()[:4].tolist()}" @@ -2429,17 +2020,6 @@ def _forward( if _compile_monitor.cudagraph_capturing_enabled and num_actual_tokens > num_cache_lines: num_actual_tokens = num_cache_lines - if DEBUG_STATE_INDICES and not _compile_monitor.cudagraph_capturing_enabled: - indices = non_spec_state_indices_tensor[:num_actual_tokens].tolist() - unique = {i for i in indices if i >= 0} - dup = len(indices) - len(unique) - indices.count(-1) - tag = " ** DUPLICATE **" if dup > 0 else "" - print( - f"[STATE-IDX KDA {self.prefix}] " - f"decodes={attn_metadata.num_decodes} prefills={attn_metadata.num_prefills} " - f"indices={indices}{tag}" - ) - q_proj_states = q_proj_states[:num_actual_tokens] k_proj_states = k_proj_states[:num_actual_tokens] v_proj_states = v_proj_states[:num_actual_tokens] @@ -3009,19 +2589,13 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - if APRIEL2_FULL_CUDA_GRAPHS: - # FULL mode: the entire forward (including dispatch) is captured as - # one CUDA graph. On layout change, graphs are re-captured. - # No per-mixer cache needed — vLLM's CUDAGraphWrapper handles it. - self._mixer_graph_cache: MixerGraphCache | None = None - else: - # PIECEWISE mode: force graph break at dispatch, mixer runs eagerly - _force_piecewise_cudagraph_for_stochastic_mixers() - self._mixer_graph_cache = MixerGraphCache() if APRIEL2_MIXER_CUDA_GRAPHS else None + # Force PIECEWISE cudagraph mode for stochastic mixers + # FULL mode captures only active mixer ops, breaking dormant mixer switching + _force_piecewise_cudagraph_for_stochastic_mixers() def set_active_mixer(self, name: str) -> None: """Set the active mixer by name.""" - if name not in self._mixer_names: + if name not in self.mixers: raise ValueError(f"Unknown mixer '{name}'. Available: {self._mixer_names}") self.active_mixer_name = name @@ -3029,29 +2603,6 @@ def get_active_mixer(self) -> str: """Get the name of the currently active mixer.""" return self.active_mixer_name - def offload_inactive_mixers(self) -> int: - """Move inactive mixer weights to CPU. Returns bytes freed. - - Also removes offloaded mixers from self.mixers (nn.ModuleDict) and - stores them in self._offloaded_mixers (plain dict). This is critical: - torch.compile/dynamo sets guards on every parameter it sees in the - module tree. If offloaded params stay in the tree with device='cpu', - a subsequent forward triggers guard failure → re-trace → crash. - Hiding them from the module tree avoids the issue entirely. - """ - freed = 0 - device_cpu = torch.device("cpu") - self._offloaded_mixers: dict[str, nn.Module] = {} - to_offload = [name for name in self.mixers if name != self.active_mixer_name] - for name in to_offload: - mixer = self.mixers[name] - for param in mixer.parameters(): - freed += param.data.nbytes - _move_module_device(mixer, device_cpu) - self._offloaded_mixers[name] = mixer - del self.mixers[name] - return freed - def forward( self, hidden_states: torch.Tensor, @@ -3391,77 +2942,13 @@ class Apriel2ForCausalLM(nn.Module, HasInnerState, SupportsPP): }, ) - # has_inner_state = True: model uses recurrent state (GDN, KDA, Mamba). - # is_hybrid = False: we handle config verification ourselves via - # MODELS_CONFIG_MAP (see config_convertor.py) rather than relying on - # vLLM's default HybridAttentionMambaModelConfig dispatch. This is - # necessary because the default dispatch crashes for pure-mamba models - # (ZeroDivisionError: num_kv_heads=0 when no attention blocks exist). + # For hybrid models has_inner_state = True + # Don't use is_hybrid=True - it triggers HybridAttentionMambaModelConfig + # which assumes all mamba-like layers have the same shape. + # Apriel2 has heterogeneous blocks, each with its own get_kv_cache_spec(). is_hybrid = False - @classmethod - def get_mamba_state_shape_from_config( - cls, - vllm_config: VllmConfig, - ) -> tuple: - """Return the largest mamba state shape across all mixer types. - - HybridAttentionMambaModelConfig.verify_and_update_config() calls this - to compute page size alignment. It only needs a conservative upper - bound — per-layer get_kv_cache_spec() handles actual heterogeneous - allocation. We return the shape of whichever mixer type produces the - largest page_size_bytes (the "envelope"). - """ - config = vllm_config.model_config.hf_config - decoder_config = getattr(config, "decoder", {}) or {} - blocks_config = get_blocks_config(decoder_config) - block_params = get_block_params(blocks_config, vllm_config) - - # Find the mamba block with the largest natural page size - best_shapes: tuple | None = None - best_page_size = 0 - for params in block_params.values(): - if isinstance(params, MambaBlockParams): - if params.natural_page_size > best_page_size: - best_page_size = params.natural_page_size - best_shapes = params.shapes - - if best_shapes is None: - # Pure attention model — return minimal shapes so - # verify_and_update_config() sees mamba_page_size=0 and returns. - return ((1, 1), (1, 1)) - - return best_shapes - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: VllmConfig, - ) -> tuple: - """Return dtypes matching the envelope mamba state shape. - - Must be consistent with get_mamba_state_shape_from_config() — returns - the dtypes of whichever mixer type has the largest page size. - """ - config = vllm_config.model_config.hf_config - decoder_config = getattr(config, "decoder", {}) or {} - blocks_config = get_blocks_config(decoder_config) - block_params = get_block_params(blocks_config, vllm_config) - - best_dtypes: tuple | None = None - best_page_size = 0 - for params in block_params.values(): - if isinstance(params, MambaBlockParams): - if params.natural_page_size > best_page_size: - best_page_size = params.natural_page_size - best_dtypes = params.dtypes - - if best_dtypes is None: - return (torch.float32, torch.float32) - - return best_dtypes - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() # Install triton allocator lazily - this runs in the vLLM subprocess @@ -3558,28 +3045,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - def offload_inactive_mixers(self) -> int: - """Offload all inactive mixer weights to CPU. - - Cannot run inside load_weights() because torch.compile captures ALL - parameters as compiled graph inputs — profile_run() would crash trying - to pass CPU tensors to the GPU graph. Instead, this is called from the - monkey-patched Worker.determine_available_memory() AFTER profile_run(). - - Returns: - Total bytes freed on GPU. - """ - total_freed = 0 - for layer in self.model.layers: - if isinstance(layer, Apriel2StochasticDecoderLayer): - total_freed += layer.mixer.offload_inactive_mixers() - - torch.cuda.synchronize() - torch.cuda.empty_cache() - freed_gib = total_freed / (1024**3) - logger.info(f"Offloaded inactive mixer weights to CPU: freed {freed_gib:.2f} GiB GPU memory") - return total_freed - def set_layer_placements(self, placement: list[str]) -> dict[int, str]: """Set the active mixer for each stochastic layer. @@ -3687,136 +3152,15 @@ def _clear_kv_cache(self) -> None: logger.info("Cleared KV cache tensors for placement switch") - def _clear_piecewise_wrappers(module: nn.Module) -> None: - """Recursively clear all piecewise CUDAGraphWrapper entries.""" - from vllm.compilation.cuda_graph import CUDAGraphWrapper - - for val in module.__dict__.values(): - if isinstance(val, CUDAGraphWrapper): - val.concrete_cudagraph_entries.clear() - elif isinstance(val, nn.Module): - _clear_piecewise_wrappers(val) - for child in module.children(): - _clear_piecewise_wrappers(child) - - def _recapture_cuda_graphs(worker) -> None: - """Invalidate and re-capture CUDA graphs after layout change. - - In FULL CUDA graph mode, the captured graphs contain GPU kernels for - the previous layout. After changing mixer assignments, we must - re-capture to bake in the new layout's kernels. - """ - model_runner = getattr(worker, "model_runner", None) - if model_runner is None: - return - - from vllm.compilation.cuda_graph import CUDAGraphWrapper - - # 1. Clear outer FULL wrapper entries - model = model_runner.model - if isinstance(model, CUDAGraphWrapper): - num_cleared = len(model.concrete_cudagraph_entries) - model.concrete_cudagraph_entries.clear() - logger.info(f"Cleared {num_cleared} FULL CUDA graph entries") - - # 2. Clear inner piecewise wrapper entries (for FULL_AND_PIECEWISE mode) - inner_model = model.unwrap() if isinstance(model, CUDAGraphWrapper) else model - _clear_piecewise_wrappers(inner_model) - - # 3. Re-capture for all batch sizes - logger.info("Re-capturing CUDA graphs for new layout...") - model_runner.capture_model() - logger.info("CUDA graph re-capture complete") - - def _swap_mixer_weights(worker, placement: list[str]) -> None: - """Swap mixer weights between GPU and CPU for placement change. - - For each layer where the active mixer changes: - 1. Offload old active → CPU, remove from ModuleDict, store in _offloaded_mixers - 2. Load new active from _offloaded_mixers → GPU, add to ModuleDict - Done layer by layer to avoid transient OOM. - """ - model = worker.get_model() - device_gpu = torch.device("cuda") - device_cpu = torch.device("cpu") - loaded = 0 - offloaded = 0 - - for layer_idx, new_mixer_name in enumerate(placement): - if layer_idx >= len(model.model.layers): - break - layer = model.model.layers[layer_idx] - if not isinstance(layer, Apriel2StochasticDecoderLayer): - continue - - stochastic = layer.mixer - old_mixer_name = stochastic.active_mixer_name - if old_mixer_name == new_mixer_name: - continue - - # 1. Offload old active → CPU, hide from module tree - old_mixer = stochastic.mixers[old_mixer_name] - _move_module_device(old_mixer, device_cpu) - del stochastic.mixers[old_mixer_name] - stochastic._offloaded_mixers[old_mixer_name] = old_mixer - offloaded += 1 - - # 2. Load new active from offloaded → GPU, restore to module tree - new_mixer = stochastic._offloaded_mixers.pop(new_mixer_name) - _move_module_device(new_mixer, device_gpu) - stochastic.mixers[new_mixer_name] = new_mixer - loaded += 1 - - if loaded or offloaded: - torch.cuda.synchronize() - torch.cuda.empty_cache() - logger.info(f"Weight swap: offloaded {offloaded} mixers to CPU, " f"loaded {loaded} mixers to GPU") - def _set_layer_placements(self, placement: list[str]) -> dict[int, str]: # Clear KV cache BEFORE changing placement to prevent reading stale data # written by a different mixer type (which could cause NaN errors) _clear_kv_cache(self) - - # Swap weights before changing active mixer (needs old active_mixer_name) - if APRIEL2_OFFLOAD_INACTIVE_MIXERS: - _swap_mixer_weights(self, placement) - - result = self.get_model().set_layer_placements(placement) - # Re-capture CUDA graphs with the new layout baked in - if APRIEL2_FULL_CUDA_GRAPHS and result: - _recapture_cuda_graphs(self) - return result + return self.get_model().set_layer_placements(placement) def _get_mixer_names(self) -> tuple[str, ...]: return self.get_model().get_mixer_names() - # -- Weight offloading: patch determine_available_memory --------------- - # torch.compile captures ALL parameters as graph inputs. If we offload - # during load_weights(), profile_run() crashes (CPU tensors in GPU graph). - # Instead, offload AFTER profile_run() but adjust available memory upward. - - if APRIEL2_OFFLOAD_INACTIVE_MIXERS: - _orig_determine_available_memory = Worker.determine_available_memory - - @torch.inference_mode() - def _determine_available_memory_with_offload(self) -> int: - result = _orig_determine_available_memory(self) - - # Offload inactive mixer weights to CPU - freed = self.get_model().offload_inactive_mixers() - if freed > 0: - self.available_kv_cache_memory_bytes += freed - self.model_runner.model_memory_usage -= freed - result = int(self.available_kv_cache_memory_bytes) - logger.info( - "Adjusted available KV cache memory: +%.2f GiB from weight offloading", - freed / (1024**3), - ) - - return result - - Worker.determine_available_memory = _determine_available_memory_with_offload - Worker.get_layer_placements = _get_layer_placements Worker.set_layer_placements = _set_layer_placements Worker.get_mixer_names = _get_mixer_names From 97b84d5855e07ddac3a28c939c64b61ea28489b2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 9 Apr 2026 20:20:45 +0000 Subject: [PATCH 10/13] Remove debug artifacts from modeling_apriel2 Remove all debug-only code that accumulated during development: - DEBUG_* flag definitions (DEBUG_GDN_LAYER, DEBUG_GDN_STATE, DEBUG_GDN_OUTPUT, DEBUG_KDA_LAYER, DEBUG_DECODER_LAYER, DEBUG_FINAL_NORM, DEBUG_LM_HEAD, DEBUG_SYNC) - _debug_print, _debug_tensor, _debug_state_stats helper methods - All if DEBUG_*: blocks with print statements - All commented-out # self._debug_* / # self._cached_* lines - SYNC-1..6 CUDA sync debug points - import os (only used by DEBUG_SYNC) Functional code and bug fixes from previous commits are unchanged. Co-Authored-By: Claude Sonnet 4.6 --- .../apriel2/vllm/modeling_apriel2.py | 290 +----------------- 1 file changed, 1 insertion(+), 289 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 3dd822603..20f07d39f 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -9,7 +9,6 @@ import logging import math -import os from collections.abc import Iterable from dataclasses import dataclass from itertools import islice @@ -157,23 +156,6 @@ def _patched(kv_cache_spec: dict) -> list: _patch_kv_cache_grouping() -# ============================================================================= -# Debug Flags -# ============================================================================= -# Top-level debug flags that control all debug output in the module. -# Set these to True to enable debugging for specific components. - -DEBUG_GDN_LAYER = False # Debug GDN layer forward pass (tensors, shapes) -DEBUG_GDN_STATE = False # Debug GDN recurrent state during decode -DEBUG_GDN_OUTPUT = False # Debug GDN output hidden states during decode -DEBUG_KDA_LAYER = False # Debug KDA layer outputs -DEBUG_DECODER_LAYER = False # Debug decoder layer outputs (residual, norm) -DEBUG_FINAL_NORM = False # Debug final norm before LM head -DEBUG_LM_HEAD = False # Debug LM head input/output -# Sync CUDA before/between GDN kernels to catch the exact source of async -# illegal-memory-access errors. Very slow — only for debugging. -DEBUG_SYNC = os.environ.get("APRIEL2_DEBUG_SYNC", "0") == "1" - # ============================================================================= # KV Cache Spec Computation @@ -1351,38 +1333,6 @@ def rearrange_mixed_qkv(self, mixed_qkv: torch.Tensor | None): value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query.contiguous(), key.contiguous(), value.contiguous() - def _debug_state_stats(self, name: str, state: torch.Tensor, seq_len: int): - """Debug recurrent state with statistics.""" - if not DEBUG_GDN_STATE or state is None: - return - flat = state.flatten() - first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) - print( - f"[vLLM-GDN {self.prefix}] {name} (seq_len={seq_len}): shape={state.shape}, " - f"mean={state.float().mean().item():.6f}, std={state.float().std().item():.6f}, " - f"min={state.float().min().item():.6f}, max={state.float().max().item():.6f}, " - f"first8=[{first8}]" - ) - - def _debug_tensor(self, name: str, t: torch.Tensor): - if not DEBUG_GDN_LAYER: - return - if t is None: - print(f"[GDN {self.prefix}] {name}: None") - return - flat = t.flatten()[:8] - vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) - print( - f"[GDN {self.prefix}] {name}: shape={t.shape}, dtype={t.dtype}, " - f"mean={t.float().mean().item():.6f}, std={t.float().std().item():.6f}, " - f"first8=[{vals}]" - ) - - def _debug_print(self, msg: str): - if not DEBUG_GDN_LAYER: - return - print(f"[GDN {self.prefix}] {msg}") - def forward( self, hidden_states: torch.Tensor, @@ -1393,30 +1343,17 @@ def forward( """Forward pass with custom op for core attention.""" num_tokens = hidden_states.size(0) - # self._cached_hidden_states = hidden_states # Cache for debug in _forward_core - # self._debug_print(f"===== FORWARD START (num_tokens={num_tokens}) =====") - # self._debug_tensor("hidden_states", hidden_states) - # Part 1: Input Projection projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_ba, _ = self.in_proj_ba(hidden_states) - # self._debug_tensor("projected_states_qkvz", projected_states_qkvz) - # self._debug_tensor("projected_states_ba", projected_states_ba) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) - # self._debug_tensor("query (after fix_ordering)", query) - # self._debug_tensor("key (after fix_ordering)", key) - # self._debug_tensor("value (after fix_ordering)", value) - # self._debug_tensor("z (after fix_ordering)", z) - # self._debug_tensor("b (after fix_ordering)", b) - # self._debug_tensor("a (after fix_ordering)", a) # Flatten heads: [tokens, heads, head_dim] -> [tokens, heads * head_dim] query = query.reshape(query.size(0), -1) key = key.reshape(key.size(0), -1) value = value.reshape(value.size(0), -1) mixed_qkv = torch.cat((query, key, value), dim=-1) - # self._debug_tensor("mixed_qkv (flattened)", mixed_qkv) # Part 2: Core Attention (Custom Op) core_attn_out = torch.zeros( @@ -1425,10 +1362,6 @@ def forward( device=hidden_states.device, ) - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - torch.cuda.synchronize() - print(f"[SYNC-1] {self.prefix}: pre-custom-op sync OK", flush=True) - torch.ops.vllm.apriel2_gdn_attention_core( mixed_qkv, b, @@ -1436,58 +1369,19 @@ def forward( core_attn_out, self.prefix, ) - # self._debug_tensor("core_attn_out (after custom op)", core_attn_out) # Part 3: Output Projection z_shape_og = z.shape core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - # self._debug_tensor("core_attn_out (before norm)", core_attn_out) - # self._debug_tensor("z (before norm)", z) - # Debug last token before norm (reshaped has tokens * heads rows) - if DEBUG_GDN_LAYER and num_tokens > 0: - num_heads = self.num_v_heads // self.tp_size - last_token_start = (num_tokens - 1) * num_heads - last_attn = core_attn_out[last_token_start : last_token_start + 1, :8] - last_z = z[last_token_start : last_token_start + 1, :8] - print( - f"[GDN {self.prefix}] core_attn_out before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_attn.flatten().float().tolist())}]" - ) - print( - f"[GDN {self.prefix}] z before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_z.flatten().float().tolist())}]" - ) - # self._debug_tensor("norm.weight", self.norm.weight) - # self._debug_print(f"norm.norm_before_gate={self.norm.norm_before_gate}, norm.eps={self.norm.eps}") core_attn_out = self.norm(core_attn_out, z) - # self._debug_tensor("core_attn_out (after norm)", core_attn_out) - # Debug last token after norm - if DEBUG_GDN_LAYER and num_tokens > 0: - last_attn_after = core_attn_out[last_token_start : last_token_start + 1, :8] - print( - f"[GDN {self.prefix}] core_attn_out after norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_attn_after.flatten().float().tolist())}]" - ) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") # Align dtype with projection weights (FA kernels may yield float32) target_dtype = self.out_proj.weight.dtype if core_attn_out.dtype != target_dtype: core_attn_out = core_attn_out.to(target_dtype) - # self._debug_tensor("core_attn_out (before out_proj)", core_attn_out) output[:num_tokens], _ = self.out_proj(core_attn_out) - # self._debug_tensor("output (final)", output[:num_tokens]) - # Show last token specifically - if DEBUG_GDN_LAYER: - last_token = output[num_tokens - 1, :8] - vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist()) - print(f"[GDN {self.prefix}] output (last token): last_token_first8=[{vals}]") - # Debug output hidden states during decode (num_tokens == 1) - if DEBUG_GDN_OUTPUT and num_tokens == 1: - flat = output[:num_tokens].flatten() - first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) - print( - f"[vLLM-GDN {self.prefix}] OUTPUT hs: shape={output[:num_tokens].shape}, mean={output[:num_tokens].float().mean().item():.6f}, std={output[:num_tokens].float().std().item():.6f}, first8=[{first8}]" - ) - # self._debug_print("===== FORWARD END =====") def _forward_core( self, @@ -1497,16 +1391,11 @@ def _forward_core( core_attn_out: torch.Tensor, ): """Core attention computation (called by custom op).""" - # self._debug_print("===== _forward_core START =====") - # self._debug_tensor("mixed_qkv (input to core)", mixed_qkv) - # self._debug_tensor("b (input to core)", b) - # self._debug_tensor("a (input to core)", a) forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: - # self._debug_print("attn_metadata is None, returning early") return assert isinstance(attn_metadata, dict) @@ -1520,10 +1409,6 @@ def _forward_core( non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor num_actual_tokens = attn_metadata.num_actual_tokens - # self._debug_print(f"num_actual_tokens={num_actual_tokens}, num_prefills={attn_metadata.num_prefills}, num_decodes={attn_metadata.num_decodes}") - # self._debug_print(f"has_initial_state={has_initial_state}") - # self._debug_print(f"non_spec_query_start_loc={non_spec_query_start_loc}") - self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] @@ -1547,38 +1432,12 @@ def _forward_core( b = b[:num_actual_tokens] a = a[:num_actual_tokens] - # self._debug_tensor("mixed_qkv (truncated)", mixed_qkv) - # self._debug_tensor("b (truncated)", b) - # self._debug_tensor("a (truncated)", a) - # Convolution conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - # self._debug_tensor("conv_weights", conv_weights) - # self._debug_tensor("conv1d.bias", self.conv1d.bias) - # self._debug_print(f"activation={self.activation}") if attn_metadata.num_prefills > 0: - # self._debug_print("Using causal_conv1d_fn (prefill path)") mixed_qkv_T = mixed_qkv.transpose(0, 1) - # self._debug_tensor("mixed_qkv_T (before conv)", mixed_qkv_T) - - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - max_ci = non_spec_state_indices_tensor.max().item() - print( - f"[PRE-CONV] {self.prefix}: conv_state.shape={conv_state.shape}, " - f"x.shape={mixed_qkv_T.shape}, " - f"cache_indices={non_spec_state_indices_tensor.tolist()} (max={max_ci}), " - f"num_cache_lines={num_cache_lines}, " - f"query_start_loc={non_spec_query_start_loc.tolist()}, " - f"has_initial_state={has_initial_state.tolist()}, " - f"num_prefills={attn_metadata.num_prefills}", - flush=True, - ) - if max_ci >= num_cache_lines: - print( - f"[OOB-CONV] {self.prefix}: cache_index {max_ci} >= num_cache_lines {num_cache_lines}!", - flush=True, - ) + mixed_qkv = causal_conv1d_fn( mixed_qkv_T, conv_weights, @@ -1590,11 +1449,7 @@ def _forward_core( query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ).transpose(0, 1) - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - torch.cuda.synchronize() - print(f"[SYNC-2] {self.prefix}: post-causal_conv1d_fn sync OK", flush=True) else: - # self._debug_print("Using causal_conv1d_update (decode path)") causal_conv1d_update( mixed_qkv, conv_state, @@ -1605,68 +1460,21 @@ def _forward_core( validate_data=True, ) - # self._debug_tensor("mixed_qkv (after conv)", mixed_qkv) - query, key, value = self.rearrange_mixed_qkv(mixed_qkv) - # self._debug_tensor("query (after rearrange)", query) - # self._debug_tensor("key (after rearrange)", key) - # self._debug_tensor("value (after rearrange)", value) # Expand K heads to V heads for grouped query attention # (matches Fast-LLM and transformers reference implementations) # Always call repeat_interleave (no-op when value_heads_per_key == 1) to avoid # conditional branches that confuse torch.compile - # self._debug_print(f"Expanding K heads to V heads (value_heads_per_key={self.value_heads_per_key})") query = query.repeat_interleave(self.value_heads_per_key, dim=2) key = key.repeat_interleave(self.value_heads_per_key, dim=2) - # self._debug_tensor("query (after expand)", query) - # self._debug_tensor("key (after expand)", key) - - # self._debug_tensor("A_log", self.A_log) - # self._debug_tensor("dt_bias", self.dt_bias) g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - torch.cuda.synchronize() - print(f"[SYNC-3] {self.prefix}: post-fused_gdn_gating sync OK", flush=True) - # self._debug_tensor("g (from gating)", g) - # self._debug_tensor("beta (from gating)", beta) # Recurrent attention if attn_metadata.num_prefills > 0: - # self._debug_print("Using chunk_gated_delta_rule (prefill)") - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - max_idx = non_spec_state_indices_tensor.max().item() - print( - f"[SYNC-4] {self.prefix}: prefill ssm_state.shape={ssm_state.shape}, " - f"indices={non_spec_state_indices_tensor.tolist()}, " - f"max_idx={max_idx}, num_cache_lines={num_cache_lines}, " - f"has_initial_state={has_initial_state.tolist()}, " - f"num_prefills={attn_metadata.num_prefills}", - flush=True, - ) - if max_idx >= ssm_state.shape[0]: - print( - f"[OOB] {self.prefix}: max_idx={max_idx} >= ssm_state.shape[0]={ssm_state.shape[0]}!", - flush=True, - ) initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 - # self._debug_tensor("initial_state", initial_state) - # Debug PREFILL INPUTS before kernel call - if DEBUG_GDN_STATE: - print(f"[vLLM-GDN {self.prefix}] PREFILL INPUTS:") - print( - f" hidden_states: shape={self._cached_hidden_states.shape}, first8={self._cached_hidden_states.flatten()[:8].tolist()}" - ) - print(f" mixed_qkv (input): shape={mixed_qkv.shape}, first8={mixed_qkv.flatten()[:8].tolist()}") - print(f" q: shape={query.shape}, first8={query.flatten()[:8].tolist()}") - print(f" k: shape={key.shape}, first8={key.flatten()[:8].tolist()}") - print(f" v: shape={value.shape}, first8={value.flatten()[:8].tolist()}") - print(f" g: shape={g.shape}, first8={g.flatten()[:8].tolist()}") - print(f" beta: shape={beta.shape}, first8={beta.flatten()[:8].tolist()}") - print(f" initial_state: {initial_state}") - print(f" cu_seqlens: {non_spec_query_start_loc}") core_out, last_state = chunk_gated_delta_rule( q=query, k=key, @@ -1679,18 +1487,8 @@ def _forward_core( head_first=False, use_qk_l2norm_in_kernel=True, ) - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - torch.cuda.synchronize() - print(f"[SYNC-5] {self.prefix}: post-chunk_gated_delta_rule sync OK", flush=True) - # self._debug_tensor("core_out (from chunk_gated_delta_rule)", core_out) - # self._debug_tensor("last_state", last_state) ssm_state[non_spec_state_indices_tensor] = last_state.to(ssm_state.dtype) else: - # self._debug_print("Using fused_recurrent_gated_delta_rule (decode)") - if DEBUG_GDN_STATE: - print( - f"[vLLM-GDN {self.prefix}] DECODE inputs: q={query.flatten()[:4].tolist()}, k={key.flatten()[:4].tolist()}, v={value.flatten()[:4].tolist()}, g={g.flatten()[:4].tolist()}, beta={beta.flatten()[:4].tolist()}" - ) # num_actual_tokens already clamped to cache lines above num_decodes = min(attn_metadata.num_decodes, num_actual_tokens) core_out, _ = fused_recurrent_gated_delta_rule( @@ -1705,14 +1503,7 @@ def _forward_core( ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], use_qk_l2norm_in_kernel=True, ) - # self._debug_tensor("core_out (from fused_recurrent)", core_out) - # if non_spec_state_indices_tensor is not None and len(non_spec_state_indices_tensor) > 0: - # actual_state = ssm_state[slot_idx:slot_idx+1] - # # self._debug_state_stats("DECODE out_state", actual_state, num_actual_tokens) - core_attn_out[:num_actual_tokens] = core_out.squeeze(0)[:num_actual_tokens] - # self._debug_tensor("core_attn_out (final output)", core_attn_out) - # self._debug_print("===== _forward_core END =====") def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Checkpoint uses "convolution", model uses "conv1d" @@ -1966,9 +1757,6 @@ def forward( dtype=hidden_states.dtype, device=hidden_states.device, ) - if DEBUG_SYNC and not torch.compiler.is_compiling() and not _compile_monitor.cudagraph_capturing_enabled: - torch.cuda.synchronize() - print(f"[SYNC-6] {self.prefix}: pre-kda_attention sync OK", flush=True) torch.ops.vllm.kda_attention( q, @@ -2369,51 +2157,22 @@ def __init__( self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) - def _debug_tensor(self, name: str, t: torch.Tensor, show_last=False): - if not DEBUG_DECODER_LAYER or t is None: - return - if show_last: - # Show last token - last = t[-1, :8] if t.dim() == 2 else t[0, -1, :8] - vals = ", ".join(f"{v:.6f}" for v in last.float().tolist()) - print(f"[vLLM Layer] {name}: shape={t.shape}, last_token_first8=[{vals}]") - else: - flat = t.flatten()[:8] - vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) - print(f"[vLLM Layer] {name}: shape={t.shape}, first8=[{vals}]") - def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor | None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: - # self._debug_tensor("input hidden_states", hidden_states) - # self._debug_tensor("input residual", residual) - if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - # self._debug_tensor("after input_layernorm", hidden_states) - # self._debug_tensor("residual after input_layernorm", residual) - output = torch.empty_like(hidden_states) self.mixer(hidden_states, output) - # self._debug_tensor("mixer output", output) - hidden_states, residual = self.post_attention_layernorm(output, residual) - # self._debug_tensor("after post_attention_layernorm", hidden_states) - # self._debug_tensor("residual after post_attention_layernorm", residual) - hidden_states = self.mlp(hidden_states) - # self._debug_tensor("after mlp", hidden_states) - # Also show last token for final layer comparison - # self._debug_tensor("after mlp (last token)", hidden_states, show_last=True) - # self._debug_tensor("residual (last token)", residual, show_last=True) - return hidden_states, residual @@ -2900,33 +2659,7 @@ def forward( if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) - # Debug final norm - if DEBUG_FINAL_NORM: - # Show LAST token (to match TF) - last_hs = hidden_states[-1, :8] - last_res = residual[-1, :8] if residual is not None else None - hs_vals = ", ".join(f"{v:.6f}" for v in last_hs.float().tolist()) - res_vals = ", ".join(f"{v:.6f}" for v in last_res.float().tolist()) if last_res is not None else "None" - print( - f"[vLLM Final] hidden_states (before norm): shape={hidden_states.shape}, last_token_first8=[{hs_vals}]" - ) - print( - f"[vLLM Final] residual (before norm): shape={residual.shape if residual is not None else None}, last_token_first8=[{res_vals}]" - ) - print( - f"[vLLM Final] norm.weight: first8=[{', '.join(f'{v:.6f}' for v in self.norm.weight.flatten()[:8].float().tolist())}]" - ) - print(f"[vLLM Final] norm.variance_epsilon={self.norm.variance_epsilon}") - hidden_states, _ = self.norm(hidden_states, residual) - - if DEBUG_FINAL_NORM: - last_out = hidden_states[-1, :8] - out_vals = ", ".join(f"{v:.6f}" for v in last_out.float().tolist()) - print( - f"[vLLM Final] hidden_states (after norm): shape={hidden_states.shape}, last_token_first8=[{out_vals}]" - ) - return hidden_states @@ -3007,28 +2740,7 @@ def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: - # Debug LM head input - if DEBUG_LM_HEAD: - flat = hidden_states.flatten()[:8] - vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) - print(f"[vLLM LM Head] input hidden_states: shape={hidden_states.shape}, first8=[{vals}]") - if self.lm_head is not None: - lm_weight = self.lm_head.weight - print( - f"[vLLM LM Head] lm_head.weight: shape={lm_weight.shape}, first8=[{', '.join(f'{v:.6f}' for v in lm_weight.flatten()[:8].float().tolist())}]" - ) - logits = self.logits_processor(self.lm_head, hidden_states) - - if DEBUG_LM_HEAD and logits is not None: - # Get last token logits - last_logits = logits[-1] if logits.dim() == 2 else logits[0, -1] - top_vals, top_idx = last_logits.topk(5) - print(f"[vLLM LM Head] logits shape={logits.shape}") - print( - f"[vLLM LM Head] last token top-5 logits: {[(idx.item(), val.item()) for idx, val in zip(top_idx, top_vals)]}" - ) - return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: From 81c243ad29fcf4efb0668070bcc40965261a32e2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 9 Apr 2026 21:10:55 +0000 Subject: [PATCH 11/13] Remove ai_docs directory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These docs covered per-mixer CUDA graph caching and layer placement change procedures — both tied to the WIP infrastructure that was removed. Co-Authored-By: Claude Sonnet 4.6 --- .../vllm/ai_docs/layer_placement_change.md | 80 ------------- .../vllm/ai_docs/per_mixer_cuda_graphs.md | 105 ------------------ 2 files changed, 185 deletions(-) delete mode 100644 fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md delete mode 100644 fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md diff --git a/fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md b/fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md deleted file mode 100644 index ff4ae0335..000000000 --- a/fast_llm_external_models/apriel2/vllm/ai_docs/layer_placement_change.md +++ /dev/null @@ -1,80 +0,0 @@ -# Layer Placement Change in Apriel2 vLLM - -## Architecture Overview - -Apriel2 is a **hybrid model** with heterogeneous decoder layers. Some layers are "stochastic" — they contain **multiple mixer sub-modules** (e.g., attention, GDN, KDA) but only run **one at a time**. The placement system lets you dynamically switch which mixer is active per layer at runtime. - -## Key Components - -### 1. `Apriel2StochasticMixer` (line ~2344) - -- Contains an `nn.ModuleDict` called `self.mixers` with all sub-mixer instances (e.g., attention, GDN, KDA) -- Tracks `self.active_mixer_name` — which mixer is currently active -- All sub-mixers have their weights loaded, but only one runs during forward pass -- Each sub-mixer gets a **virtual layer index** (`layer_idx + (mixer_index+1) * num_layers`) so they each get separate KV cache allocations without collisions - -### 2. `Apriel2StochasticDecoderLayer` (line ~2513) - -- Wraps `Apriel2StochasticMixer` + MLP + layer norms -- Exposes `set_active_mixer(name)` / `get_active_mixer()` which delegate to the mixer - -### 3. Dynamic dispatch via custom op (line ~870) - -- `stochastic_mixer_dispatch` is registered as a `vllm::stochastic_mixer_dispatch` custom op -- This op is added to vLLM's `_attention_ops` splitting ops list, causing **graph breaks** in torch.compile -- At runtime, it looks up the `Apriel2StochasticMixer` from `forward_context.no_compile_layers[layer_name]`, reads `active_mixer_name`, and forwards to that mixer -- The fake impl just copies input→output to satisfy the compiler's data dependency analysis - -## The Placement Change Call Chain - -From the debug script (`debug_offline.py`): - -```python -llm.collective_rpc("set_layer_placements", args=(placement,)) -``` - -1. **Worker monkey-patch** (line ~2962): `_patch_worker_for_placement_switching()` runs at import time and adds `set_layer_placements`/`get_layer_placements`/`get_mixer_names` methods to `vllm.v1.worker.gpu_worker.Worker` - -2. **`Worker._set_layer_placements`** (line ~3003): - - First calls `_clear_kv_cache(self)` — zeroes out **all** KV cache tensors to prevent stale data from a different mixer type causing NaN errors - - Then calls `self.get_model().set_layer_placements(placement)` - -3. **`Apriel2ForCausalLM.set_layer_placements`** (line ~2896): - - Iterates through all layers - - For each layer that is an `Apriel2StochasticDecoderLayer`, calls `layer.set_active_mixer(mixer_name)` with the corresponding entry from the placement list - -4. **`Apriel2StochasticMixer.set_active_mixer`** (line ~2454): - - Simply sets `self.active_mixer_name = name` (after validation) - -5. On the **next `llm.generate()` call**, the forward pass hits `stochastic_mixer_dispatch` which reads the updated `active_mixer_name` and routes to the new mixer. - -## Summary Diagram - -``` -debug_offline.py - | - +-- llm.collective_rpc("get_mixer_names") - | -> Worker.get_mixer_names -> model.get_mixer_names - | -> returns ("attention", "gdn", ...) from first stochastic layer - | - +-- llm.collective_rpc("get_layer_placements") - | -> Worker.get_layer_placements -> model.get_layer_placements - | -> returns {layer_idx: active_mixer_name} for stochastic layers - | - +-- llm.collective_rpc("set_layer_placements", args=(placement,)) - | -> Worker._set_layer_placements - | +-- _clear_kv_cache() <- zero all cache tensors - | +-- model.set_layer_placements(placement) - | +-- for each stochastic layer: - | layer.mixer.active_mixer_name = new_name - | - +-- llm.generate(prompts, ...) - -> forward pass per layer: - -> stochastic_mixer_dispatch (custom op, graph break) - -> looks up self.active_mixer_name - -> calls active_mixer.forward(hidden_states, output, positions) -``` - -## Key Insight - -All mixer weights are **always loaded** — switching is just flipping `active_mixer_name` and clearing the cache. The custom op mechanism ensures this dynamic routing works even with torch.compile/CUDA graphs by forcing graph breaks at dispatch points. diff --git a/fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md b/fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md deleted file mode 100644 index f90327565..000000000 --- a/fast_llm_external_models/apriel2/vllm/ai_docs/per_mixer_cuda_graphs.md +++ /dev/null @@ -1,105 +0,0 @@ -# Per-Mixer CUDA Graph Caching for Stochastic Mixers - -## Problem - -Apriel2's supernet uses `Apriel2StochasticMixer` — a wrapper that routes each layer to one of several mixer types (attention, GDN, KDA, sliding_window) based on the active placement. The `stochastic_mixer_dispatch` custom op is registered as a CUDA graph splitting op, forcing vLLM into **PIECEWISE** mode. - -In PIECEWISE mode, vLLM captures CUDA graphs for the compute pieces between split points (norms, MLPs) but runs the split points themselves eagerly. This means every mixer forward at every layer incurs full kernel launch overhead on every decode step. - -**Measured impact**: PIECEWISE mode achieves ~290 tok/s, while FULL CUDA graph mode with a fixed layout (no supernet) achieves ~900 tok/s — a **3x gap** from the split points running eagerly. - -## Approach - -Cache a separate `torch.cuda.CUDAGraph` per (mixer_name, num_tokens) at each stochastic layer. During decode, replay the cached graph instead of running the mixer eagerly. - -### Implementation (in `modeling_apriel2.py`) - -- **`APRIEL2_MIXER_CUDA_GRAPHS`** env var — gates the feature (default `"0"`) -- **`MixerGraphEntry`** — dataclass holding a captured graph + input/output pointer addresses -- **`MixerGraphCache`** — per-layer cache keyed by `(mixer_name, num_tokens)` -- **`_capture_all_mixers_for_num_tokens()`** — captures graphs during `capture_model()` with eager warmup before each capture (for Triton autotuning) -- **`_batch_has_prefill()`** — detects mixed prefill-decode batches that can't use graph replay -- **`stochastic_mixer_dispatch`** modified with capture/replay/eager-fallback logic -- Cache instance stored as `Apriel2StochasticMixer._mixer_graph_cache` - -### Dispatch Flow - -```text -stochastic_mixer_dispatch(hidden_states, output, positions, layer_name): - if cache is not None: - if capturing and runtime_mode == PIECEWISE: - → capture graphs for all/active mixers, return - if not prefill_batch and cache.has(active_mixer, num_tokens): - → cache.replay(), return - → eager fallback: active_mixer(hidden_states, output, positions) -``` - -## Bugs Encountered & Fixed - -### 1. CUBLAS_STATUS_NOT_INITIALIZED during profile_run - -`cudagraph_capturing_enabled` defaults to `True` in `vllm.compilation.monitor`. During `profile_run()` (before `capture_model()`), our code tried to capture graphs, but cuBLAS wasn't initialized yet. - -**Fix**: Gate capture on `runtime_mode != CUDAGraphMode.NONE` (NONE during profile_run, PIECEWISE during capture_model). - -### 2. Triton autotuning inside graph capture - -KDA's `fused_kda_gate` uses `@triton.autotune`. First call triggers benchmarking with `cuda.synchronize()` — illegal during stream capture. - -**Fix**: Run each mixer eagerly once before capturing (warmup triggers autotuning outside capture context). - -### 3. GPU memory pressure from too many captured graphs (CRITICAL) - -Capturing graphs for all mixers at all batch sizes creates ~5,040 graphs (48 layers x 3 mixers x ~35 batch sizes). This causes a **2.2x throughput regression** regardless of whether graphs are replayed. - -## Memory Pressure Investigation - -Systematic isolation of the regression source: - -| Test Configuration | Graphs | Warmup tok/s | vs Baseline | -| ------------------------------------------- | ------ | ------------ | ----------- | -| `CUDA_GRAPHS=0` (baseline) | 0 | 290 | 1.0x | -| `CUDA_GRAPHS=1`, cache exists but empty | 0 | 290 | 1.0x | -| `CUDA_GRAPHS=1`, active mixer only captured | ~1,680 | 179 | 0.62x | -| `CUDA_GRAPHS=1`, capture only (no replay) | ~5,040 | 132 | 0.46x | -| `CUDA_GRAPHS=1`, capture + replay | ~5,040 | 125 | 0.43x | -| `CUDA_GRAPHS=1`, private pool + replay | ~5,040 | 126 | 0.43x | - -**Key findings**: - -1. **Python overhead is negligible** — empty cache has zero impact (290 tok/s) -2. **Graph replay adds ~5% cost** — minimal compared to the capture overhead -3. **Private graph pool doesn't help** — total GPU memory consumption is the issue, not fragmentation of vLLM's global pool -4. **Regression scales with graph count** — 1,680 graphs = 0.62x, 5,040 = 0.43x -5. The captured graphs consume GPU memory that degrades all inference operations (likely L2 cache pressure, TLB misses, or reduced memory for temporary allocations) - -## Current State - -The implementation is functionally correct but the "capture everything upfront" strategy is not viable due to memory pressure. The code remains in `modeling_apriel2.py` gated behind `APRIEL2_MIXER_CUDA_GRAPHS=1` (disabled by default). - -## Proposed Next Approach: Lazy Per-Placement Capture - -Instead of capturing all mixers for all batch sizes during `capture_model()`: - -1. **On placement set**: capture graphs only for the active mixer at each layer, only for batch sizes actually encountered -2. **On placement change**: invalidate old cache (free GPU memory), re-capture for the new placement -3. **Lazy batch sizes**: capture on first encounter of a new batch size during decode, not upfront for all 35 sizes - -This would keep the graph count to ~48 (one per layer per active batch size), well within the safe memory budget. - -### Open Questions - -- **TP > 1 compatibility**: NCCL must be in graph-safe mode for captures involving collective ops. During `capture_model()` this is guaranteed; during inference it is not. Lazy capture may only be safe at TP=1. -- **Capture-during-inference feasibility**: Need to verify that `torch.cuda.graph()` capture works correctly when called from a piecewise split point during normal inference (not during `capture_model()`). -- **Warmup cost**: Each lazy capture requires an eager warmup (for Triton autotuning) + the capture itself. This adds latency to the first decode step after a placement change or new batch size. - -## Reference: vLLM Startup Phases - -```text -load_weights → profile_run() → allocate KV cache → capture_model() → inference - │ │ - │ cudagraph_capturing=True │ cudagraph_capturing=True - │ runtime_mode=NONE │ runtime_mode=PIECEWISE - │ cuBLAS NOT initialized │ cuBLAS initialized - │ DO NOT capture here │ Safe to capture -``` From f790a57a209acef356f35f68b86cf1e068d22daa Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 10 Apr 2026 15:50:34 +0000 Subject: [PATCH 12/13] Fix A_log/dt_bias weight loading to use weight_loader for TP sharding In Apriel2GatedDeltaNet and Apriel2KDAMixer, A_log and dt_bias parameters were loaded with .data.copy_() which bypasses TP sharding. This caused a size mismatch (e.g. 16 vs 32) when running with tensor_parallel_size > 1. Switch to weight_loader() which correctly handles the sharding. Co-Authored-By: Claude Opus 4.6 --- fast_llm_external_models/apriel2/vllm/modeling_apriel2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py index 20f07d39f..6049f913a 100644 --- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -1525,10 +1525,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self.norm.weight.data.copy_(weight) loaded.add("norm.weight") elif name == "A_log": - self.A_log.data.copy_(weight) + self.A_log.weight_loader(self.A_log, weight) loaded.add("A_log") elif name == "dt_bias": - self.dt_bias.data.copy_(weight) + self.dt_bias.weight_loader(self.dt_bias, weight) loaded.add("dt_bias") return loaded @@ -1963,10 +1963,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self.g_b_proj.weight_loader(self.g_b_proj.weight, weight) loaded.add("g_b_proj.weight") elif name == "A_log": - self.A_log.data.copy_(weight) + self.A_log.weight_loader(self.A_log, weight) loaded.add("A_log") elif name == "dt_bias": - self.dt_bias.data.copy_(weight) + self.dt_bias.weight_loader(self.dt_bias, weight) loaded.add("dt_bias") return loaded From 5fd8e505e36cc0001cb6387278585943416fad4b Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 10 Apr 2026 15:52:50 +0000 Subject: [PATCH 13/13] Add standalone vLLM plugin installer (no full Fast-LLM dependency) Adds apriel2-vllm-plugin/pyproject.toml that installs only the vLLM plugin entry point and its minimal dependencies (torch, transformers, einops). This avoids pulling in the full Fast-LLM training framework. Usage: pip install -e ./apriel2-vllm-plugin/ Co-Authored-By: Claude Opus 4.6 --- apriel2-vllm-plugin/pyproject.toml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 apriel2-vllm-plugin/pyproject.toml diff --git a/apriel2-vllm-plugin/pyproject.toml b/apriel2-vllm-plugin/pyproject.toml new file mode 100644 index 000000000..e6a4dc07e --- /dev/null +++ b/apriel2-vllm-plugin/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.build_meta" + +[project] +name = "apriel2-vllm-plugin" +version = "0.1.0" +description = "Standalone vLLM plugin for Apriel2 models (extracted from Fast-LLM)" +requires-python = ">=3.12" +dependencies = [ + "torch", + "transformers", + "einops", +] + +[project.entry-points."vllm.general_plugins"] +apriel2 = "fast_llm_external_models.apriel2.vllm.config_convertor:register" + +[tool.setuptools.packages.find] +where = [".."] +include = [ + "fast_llm_external_models", + "fast_llm_external_models.apriel2", + "fast_llm_external_models.apriel2.vllm", +] + +[tool.setuptools.package-dir] +"" = ".."