Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions apriel2-vllm-plugin/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Build configuration for the standalone Apriel2 vLLM plugin.
# PEP 517/518 build system; setuptools>=64 is required for the
# pyproject-native package-dir / packages.find layout used below.
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "apriel2-vllm-plugin"
version = "0.1.0"
description = "Standalone vLLM plugin for Apriel2 models (extracted from Fast-LLM)"
# NOTE(review): 3.12 floor — presumably matches Fast-LLM's own requirement; confirm.
requires-python = ">=3.12"
dependencies = [
"torch",
"transformers",
"einops",
]

# vLLM discovers "general plugins" through this entry-point group and invokes
# the referenced callable at startup: here, config_convertor.register(), which
# adds the Apriel2 model classes and config convertors to vLLM's registries.
[project.entry-points."vllm.general_plugins"]
apriel2 = "fast_llm_external_models.apriel2.vllm.config_convertor:register"

# The packaged source lives one directory up (inside the Fast-LLM checkout);
# only the Apriel2 vLLM subpackage chain is included in the distribution.
[tool.setuptools.packages.find]
where = [".."]
include = [
"fast_llm_external_models",
"fast_llm_external_models.apriel2",
"fast_llm_external_models.apriel2.vllm",
]

# Map the root package namespace to the parent directory so the
# fast_llm_external_models packages above resolve from there.
[tool.setuptools.package-dir]
"" = ".."
46 changes: 45 additions & 1 deletion fast_llm_external_models/apriel2/vllm/config_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
"""

from vllm import ModelRegistry
from vllm.model_executor.models.config import (
MODELS_CONFIG_MAP,
HybridAttentionMambaModelConfig,
MambaModelConfig,
VerifyAndUpdateConfig,
)
from vllm.transformers_utils.model_arch_config_convertor import (
MODEL_ARCH_CONFIG_CONVERTORS,
ModelArchConfigConvertorBase,
Expand Down Expand Up @@ -66,6 +72,39 @@ def get_head_size(self) -> int:
return self._get_first_attention_block().get("head_size", 0)


class Apriel2ModelConfig(VerifyAndUpdateConfig):
    """Config handler for Apriel2 models with heterogeneous mixer types.

    An Apriel2 decoder may be pure-attention, pure-mamba, or a hybrid of
    the two, depending on its decoder config. vLLM's default ``is_hybrid``
    dispatch unconditionally routes through
    ``HybridAttentionMambaModelConfig``, which crashes for pure-mamba
    models (``ZeroDivisionError`` when ``num_kv_heads=0``).

    This handler reads ``layers_block_type`` from the HF config to
    determine the actual layer composition and delegates to the matching
    specialized config handler.
    """

    @staticmethod
    def verify_and_update_config(vllm_config) -> None:
        block_types = getattr(
            vllm_config.model_config.hf_config, "layers_block_type", None
        )
        if block_types is None:
            # No layer-type metadata available: assume a plain transformer
            # and leave the config untouched.
            return

        mixer_kinds = set(block_types)
        uses_attention = "attention" in mixer_kinds
        uses_mamba = "mamba" in mixer_kinds

        if uses_attention and uses_mamba:
            # Hybrid model: attention/mamba page-size alignment is required.
            HybridAttentionMambaModelConfig.verify_and_update_config(vllm_config)
        elif uses_mamba:
            # Pure mamba: enable FULL_AND_PIECEWISE, set mamba_block_size.
            MambaModelConfig.verify_and_update_config(vllm_config)
        # Pure attention needs no special configuration.


def register():
"""Register Apriel2 models and config convertors with vLLM.

Expand Down Expand Up @@ -126,7 +165,7 @@ def register():
# Best-effort only; vLLM can still proceed with the generic config.
pass

# Register model class
# Register model class and config handler.
# Note: some exported checkpoints may list "Apriel2ForConditionalGeneration"
# in config.json's "architectures". vLLM's model selection is driven by that
# field, so we alias it to the same vLLM implementation for text-only usage.
Expand All @@ -135,3 +174,8 @@ def register():
arch,
"fast_llm_external_models.apriel2.vllm:Apriel2ForCausalLM",
)
# Register in MODELS_CONFIG_MAP so vLLM calls our handler instead of
# relying on the is_hybrid class attribute dispatch (which can't handle
# models that are sometimes hybrid, sometimes pure-mamba).
if arch not in MODELS_CONFIG_MAP:
MODELS_CONFIG_MAP[arch] = Apriel2ModelConfig
Loading