From bca107417eb36f77ecb1260b1219e9308f0e933c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 1 Apr 2026 17:13:14 -0400 Subject: [PATCH 1/9] Support transformers 4.x and 5.x simultaneously - Widen transformers version constraint to >=4.57.3,<6.0.0 - Version-gate PretrainedConfig init (__init__ vs __post_init__) and dtype attribute (torch_dtype vs dtype) using dataclasses.is_dataclass detection - Fall back to transformers.modeling_utils.no_init_weights for 4.x - Support both rope_parameters (5.x) and rope_theta/rope_scaling (4.x) in Llama import/export config - Handle both attribute paths for vision_tower in multimodal HF model test - Fix mtp_llama LlamaRotaryEmbedding to handle both rope config formats - Add _gdn_fla_available and _kda_fla_available flags to apriel2; use them to properly skip backup SSM tests when fla kernels are absent - Update CLAUDE.md with redirect-to-file and external model test guidance Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 170 ++++++++++++++++++ fast_llm/engine/inference/config.py | 39 ++-- fast_llm/engine/inference/huggingface.py | 10 +- fast_llm/models/gpt/conversion/llama.py | 150 ++++++++++------ fast_llm/models/gpt/huggingface.py | 4 +- fast_llm/models/multimodal/huggingface.py | 4 +- .../apriel2/modeling_apriel2.py | 6 +- .../mtp_llama/modeling_mtp_llama.py | 25 ++- setup.cfg | 2 +- tests/layers/test_ssm.py | 8 +- tests/models/test_checkpoint.py | 9 +- 11 files changed, 348 insertions(+), 79 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..e736cfa69 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,170 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +### Virtual environment + +The project uses a venv at `venv/`. Always activate it before running any Python commands: + +```bash +source venv/bin/activate +``` + +### Installation + +```bash +# Full install with GPU support (requires CUDA) +pip install -e ".[CORE,OPTIONAL,DEV]" + +# CPU-only install (for IDE support, no GPU required) +FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE,DEV]" --no-build-isolation +``` + +### Pre-commit hooks + +```bash +pip install pre-commit +pre-commit install +``` + +Hooks run Black (line length 119), isort, autoflake, and pyupgrade automatically on commit. + +### Running tests + +Always redirect output to a file to avoid truncation, e.g. `pytest ... 2>&1 | tee /tmp/pytest_out.txt`. + +```bash +# All tests +pytest -v -n 8 tests/ + +# Single test file or function +pytest -v tests/layers/test_lm_losses.py +pytest -v tests/layers/test_lm_losses.py::test_name + +# Run extra-slow tests (disabled by default) +pytest -v -n 8 --run-extra-slow tests/ + +# Filter by model type +pytest -v -n 8 --models gpt tests/ + +# Test Triton kernels on CPU (no GPU required) +TRITON_INTERPRET=1 pytest -v tests/layers/test_lm_losses.py +``` + +The test suite sets `FAST_LLM_SKIP_TRITON_AUTOTUNE=TRUE` automatically. Tests that require distributed execution spawn child processes via `torchrun`. `TRITON_INTERPRET=1` enables the Triton interpreter so Triton kernels run on CPU — use this when developing or debugging Triton code without a GPU. + +When working with external models (`fast_llm_external_models/`), also run: + +```bash +pytest -v -n 8 fast_llm_external_models/tests/ +``` + +### CLI + +```bash +# General form +fast-llm [--config config.yaml] [key=value overrides...] 
+ +# Validate config without running +fast-llm train gpt --config config.yaml --validate + +# Example: train GPT +fast-llm train gpt --config examples/mistral-4-node-benchmark.yaml +``` + +## Architecture + +### Configuration system (`fast_llm/config.py`) + +The core infrastructure. Every config is a frozen dataclass decorated with `@config_class()` that inherits from `Config`. Fields use `Field(default=..., desc=..., hint=FieldHint.X)` with hints that control serialization verbosity and validation: + +- `FieldHint.architecture` — defines model structure; compared across checkpoints +- `FieldHint.core` — always required explicitly +- `FieldHint.optional/performance/stability/feature/expert` — optional tuning knobs +- `FieldHint.derived` — computed from other fields, never serialized + +Dynamic dispatch (for YAML `type:` keys) uses `@config_class(dynamic_type={BaseClass: "name"})`. The registry enables subclass selection from config files. + +`RunnableConfig` (in `fast_llm/engine/config_utils/runnable.py`) extends `Config` with CLI parsing. `fast-llm train gpt` chains two levels of dynamic type dispatch: `train` selects the trainer subcommand, `gpt` selects `GPTModelConfig`. + +**Important:** Config modules (`config.py` files) must not import heavy third-party packages (torch, numpy, etc.) at the top level — only barebone dependencies — so configs can be validated without a full GPU environment. + +### Engine (`fast_llm/engine/`) + +The training engine is model-agnostic. Key components: + +- **`distributed/`** — `DistributedConfig` defines tensor/pipeline/data/sequence parallelism. `Distributed` manages NCCL process groups and knows which ranks are peers in each dimension. + +- **`base_model/`** — `BaseModel` (abstract) and `LayerBase` define the layer interface. A model is a flat list of `Layer` objects returned by `get_layers()`. Layers are the unit of pipeline parallelism. + +- **`multi_stage/`** — `MultiStageModel` wraps a `BaseModel` and handles: + - Splitting layers across pipeline stages + - ZeRO-1/2/3 weight/gradient/optimizer-state sharding via `FSDP` + - Tied parameter management across stages + +- **`schedule/`** — `Schedule` builds the micro-batch execution graph; `ScheduleRunner` executes it, orchestrating pipeline-parallel forward/backward passes with gradient accumulation. + +- **`optimizer/`** — AdamW implementation in `fast_llm/functional/triton/adam.py`; `Optimizer` manages `ParamGroup`s with per-group LR schedules. + +- **`training/`** — `Trainer` base class wires everything together. Subclasses (e.g., `GPTTrainer`) provide model-specific data loading. Training loop, checkpointing, evaluation, and W&B logging are handled here. + +- **`checkpoint/`** — Supports Fast-LLM distributed format, safetensors, and HuggingFace format conversion. + +### Layers (`fast_llm/layers/`) + +Reusable building blocks consumed by models: + +- `common/` — `Linear` (with tensor-parallel variants), normalization (LayerNorm, RMSNorm), PEFT (LoRA) +- `attention/` — Multi-head/grouped-query attention, RoPE embeddings +- `decoder/` — `TransformerBlock` composing attention + MLP, various MLP variants (dense, MoE) +- `language_model/` — `LanguageModelEmbedding`, `LanguageModelHead`, loss functions (CE, entropy, z-loss, DPO, GRPO) +- `ssm/` — State space model layers (Mamba) +- `vision/` — Vision encoder layers for multimodal models + +### Models (`fast_llm/models/`) + +Concrete model implementations: + +- `gpt/` — The main model family. `GPTBaseModel` assembles embedding + decoder blocks + LM head. 
`GPTModelConfig` registers HuggingFace checkpoint converters for Llama, Mistral, Mixtral, Qwen2, etc. `GPTTrainer` is the entry point for `fast-llm train gpt`. +- `multimodal/` — Vision-language model built on top of GPT. + +### Functional / Triton kernels (`fast_llm/functional/`) + +Low-level ops with optional Triton acceleration. Triton kernels live in `fast_llm/functional/triton/` and fall back to PyTorch when unavailable. Key kernels: fused entropy loss, z-loss, Adam, sparse linear (MoE), GRPO loss. + +`fast_llm/functional/triton/__init__.py` is the Triton entry point — it handles import errors, exposes `triton_available`/`triton_interpret` flags, and contains workarounds for Triton interpreter bugs. If a third-party Triton bug needs fixing, monkeypatch it here rather than editing `third_party/`. + +**`third_party/` is read-only.** Never edit files under `third_party/`. Fix issues by monkeypatching the relevant module attribute in Fast-LLM code (typically `fast_llm/functional/triton/__init__.py`). + +### Data (`fast_llm/data/`) + +- `dataset/` — Memmap, blended, concatenated, FIM, random, and streaming datasets +- `data/gpt/` — GPT-specific data pipeline (tokenized memmap sequences) +- `preparation/` — Offline dataset preprocessing tools +- `document/` — Document-level abstractions for variable-length inputs + +## Testing Conventions + +Tests live in `tests/`. The following patterns work well in this codebase. + +**Structure:** +- Prefer thin test bodies: construct inputs, call the function, compare outputs. Put expected-value derivation in a helper dataclass with `@cached_property` fields built up step by step. +- Return `None` from an `expected_*` property when a feature flag is off so the test body stays unconditional. + +**Parametrization:** +- Generate test cases as a cross-product of `base_cases × variants` via list comprehension with a `_make_name` helper and a filter clause for invalid combinations. +- Include boundary inputs (e.g., sequences shorter than a parameter, zero padding) as named base cases with explanatory comments. + +**Reference implementations:** +- Reference/ground-truth functions in tests must stay independently correct. Never change a reference to match new implementation behavior — if they disagree, suspect the new implementation first. +- Prefer plain Python loops over tensor ops in reference helpers to stay clearly independent from the implementation. + +## Code Style + +- **Imports**: Third-party → `import package.module` (keep fully qualified). First-party → `from fast_llm.module import Thing`. No relative imports. Optional/slow imports inside methods or under `if typing.TYPE_CHECKING:`. +- **Naming**: No abbreviations (use `batch_size` not `bs`). Private members get a single `_` prefix; never use `__`. Keep public interfaces lean. +- **Types**: Always type-hint public interfaces. Use modern syntax (`X | Y`, `list[T]` not `List[T]`). +- **Paths**: Use `pathlib.Path`, not `os.path`. +- **Python version**: 3.12+. 
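The diffs that follow all hinge on one runtime check for the installed transformers major version. A minimal sketch of that idiom, assuming only what the hunks below introduce (the `_TRANSFORMERS_V5` name matches the flag added in the diffs; the `config_dtype` helper is illustrative and not part of the patch):

```python
import dataclasses

import transformers

# The series detects transformers 5.x by the fact that PretrainedConfig is a
# dataclass there, avoiding any version-string parsing.
_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig)


def config_dtype(config: transformers.PretrainedConfig):
    # 5.x renames the `torch_dtype` config attribute to `dtype`; read whichever
    # attribute the installed version exposes (as the config.py hunk below does).
    return config.dtype if _TRANSFORMERS_V5 else config.torch_dtype
```
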
diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index d19e2478d..59966a99d 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -1,4 +1,5 @@ import copy +import dataclasses import logging import os import pathlib @@ -12,20 +13,34 @@ logger = logging.getLogger(__name__) +_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) + class HuggingfaceModelConfig(transformers.PretrainedConfig): model_type = "fast_llm" model_config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig + fast_llm_config: FastLLMModelConfig | None = None + use_cache: bool = True + + if _TRANSFORMERS_V5: + + def __post_init__(self, **kwargs): + # Needed for `to_diff_dict` (`__repr__`) + if self.fast_llm_config is None: + self.fast_llm_config = self.model_config_class() + super().__post_init__(**kwargs) + if self.dtype is not None: + assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch + + else: - def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): - # Needed for `to_diff_dict` (`__repr__`) - if fast_llm_config is None: - fast_llm_config = self.model_config_class() - self.fast_llm_config = fast_llm_config - self.use_cache = kwargs.pop("use_cache", True) - super().__init__(**kwargs) - if self.torch_dtype is not None: - assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch + def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): + # Needed for `to_diff_dict` (`__repr__`) + self.fast_llm_config = fast_llm_config if fast_llm_config is not None else self.model_config_class() + self.use_cache = kwargs.pop("use_cache", True) + super().__init__(**kwargs) + if self.torch_dtype is not None: + assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None: # Hack the method to save at the right place. 
@@ -88,9 +103,9 @@ def _get_config_dict( ) metadata = cls.model_config_class.load_metadata(pretrained) updates = {} - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - updates[("distributed", "compute_dtype")] = torch_dtype + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + if dtype is not None: + updates[("distributed", "compute_dtype")] = dtype fast_llm_config = cls.model_config_class.from_metadata( pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates ) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 5a07bd51b..d5fb3cf7e 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -18,6 +18,12 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.utils import Assert +try: + from transformers.initialization import no_init_weights as transformers_no_init_weights +except ImportError: + from transformers.modeling_utils import no_init_weights as transformers_no_init_weights + + logger = logging.getLogger(__name__) @@ -38,7 +44,7 @@ def __init__( **kwargs, ): if config is None: - config = self.config_class(fast_llm_model.config) + config = self.config_class(fast_llm_config=fast_llm_model.config) assert self.runner_class.model_class.config_class is config.model_config_class assert config.fast_llm_config is fast_llm_model.config @@ -70,7 +76,7 @@ def __init__( # Transformers needs to be able to inspect the base model. self.fast_llm_base_model = fast_llm_model.base_model - with transformers.modeling_utils.no_init_weights(): + with transformers_no_init_weights(): self.post_init() if fast_llm_model.config.multi_stage.zero_stage == 3: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 38dc38586..243e09ade 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -1,7 +1,9 @@ +import dataclasses import logging import typing import torch +import transformers from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -30,6 +32,8 @@ from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div, safe_merge_dicts +_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) + logger = logging.getLogger(__name__) @@ -188,36 +192,68 @@ def import_weight( class LlamaAttentionConverter: @classmethod def import_config(cls, config: dict) -> dict: - try: - rope_type = config["rope_scaling"]["rope_type"] - except (KeyError, TypeError): - rope_type = "default" - rotary_config = { - "type": rope_type, - "theta": config["rope_theta"], - } - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": config["rope_scaling"]["factor"], - "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], - "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": config["rope_scaling"]["attention_factor"], - "beta_fast": config["rope_scaling"]["beta_fast"], - "beta_slow": config["rope_scaling"]["beta_slow"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], - } - ) + # transformers 5.x consolidates rope_theta + rope_scaling into rope_parameters + if "rope_parameters" in config: + 
rope_params = config["rope_parameters"] + rope_type = rope_params.get("rope_type", "default") + rotary_config = { + "type": rope_type, + "theta": rope_params["rope_theta"], + } + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + # transformers 4.x format: rope_theta at top level, rope_scaling separate + try: + rope_type = config["rope_scaling"]["rope_type"] + except (KeyError, TypeError): + rope_type = "default" + rotary_config = { + "type": rope_type, + "theta": config["rope_theta"], + } + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": config["rope_scaling"]["factor"], + "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], + "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": config["rope_scaling"]["attention_factor"], + "beta_fast": config["rope_scaling"]["beta_fast"], + "beta_slow": config["rope_scaling"]["beta_slow"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") out = { "rotary": rotary_config, "heads": config["num_attention_heads"], @@ -235,36 +271,46 @@ def import_config(cls, config: dict) -> dict: def export_config(cls, config: AttentionConfig) -> dict: cls._check_config(config) Assert.eq(config.softmax_scale_power, 0.5) - out = { + rope_parameters = {"rope_theta": config.rotary.theta} + if type(config.rotary) is DefaultRotaryConfig: + rope_parameters["rope_type"] = "default" + elif type(config.rotary) is Llama3RotaryConfig: + rope_parameters.update( + { + "rope_type": "llama3", + "factor": config.rotary.scale_factor, + "low_freq_factor": config.rotary.low_frequency_factor, + "high_freq_factor": config.rotary.high_frequency_factor, + "original_max_position_embeddings": config.rotary.original_context_length, + } + ) + elif type(config.rotary) is YarnRotaryConfig: + rope_parameters.update( + { + "rope_type": "yarn", + "attention_factor": config.rotary.attention_factor, + "beta_fast": config.rotary.beta_fast, + "beta_slow": config.rotary.beta_slow, + "original_max_position_embeddings": config.rotary.original_context_length, + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + + common = { "num_attention_heads": config.heads, "num_key_value_heads": config.head_groups, "head_dim": config.head_size, "attention_bias": config.add_linear_biases, "attention_dropout": config.dropout, - "rope_theta": config.rotary.theta, } - if type(config.rotary) is DefaultRotaryConfig: - pass - elif type(config.rotary) is Llama3RotaryConfig: 
- out["rope_scaling"] = { - "rope_type": "llama3", - "factor": config.rotary.scale_factor, - "low_freq_factor": config.rotary.low_frequency_factor, - "high_freq_factor": config.rotary.high_frequency_factor, - "original_max_position_embeddings": config.rotary.original_context_length, - } - elif type(config.rotary) is YarnRotaryConfig: - out["rope_scaling"] = { - "rope_type": "yarn", - "attention_factor": config.rotary.attention_factor, - "beta_fast": config.rotary.beta_fast, - "beta_slow": config.rotary.beta_slow, - "original_max_position_embeddings": config.rotary.original_context_length, - } + if _TRANSFORMERS_V5: + return {**common, "rope_parameters": rope_parameters} else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - - return out + out = {**common, "rope_theta": rope_parameters["rope_theta"]} + if type(config.rotary) is not DefaultRotaryConfig: + out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} + return out @classmethod def _check_config(cls, config: AttentionConfig) -> None: diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 55c30c7ee..1fcb3fc25 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -20,7 +20,9 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): model_type = "fast_llm_gpt" model_config_class = GPTModelConfig - fast_llm_config: GPTModelConfig + + if typing.TYPE_CHECKING: + fast_llm_config: GPTModelConfig class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): diff --git a/fast_llm/models/multimodal/huggingface.py b/fast_llm/models/multimodal/huggingface.py index 8bf14d715..93770b446 100644 --- a/fast_llm/models/multimodal/huggingface.py +++ b/fast_llm/models/multimodal/huggingface.py @@ -22,7 +22,9 @@ class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig): model_type = "fast_llm_multi_modal" model_config_class = MultiModalModelConfig - fast_llm_config: MultiModalModelConfig + + if typing.TYPE_CHECKING: + fast_llm_config: MultiModalModelConfig class HuggingfaceMultiModalModelForCausalLM(HuggingfaceGPTModelForCausalLM): diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 9e82dfc4f..6501efde6 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -48,6 +48,9 @@ fused_recurrent_kda = None fused_kda_gate = None +_gdn_fla_available = chunk_gated_delta_rule is not None and rms_norm_gated is not None +_kda_fla_available = chunk_kda is not None + try: from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn @@ -289,7 +292,7 @@ def get_mask_sizes(self, cache_position, layer_idx): For SSM/linear layers: kv_offset = 0, kv_length = query_length (no KV cache to attend to) """ - query_length = cache_position.shape[0] + query_length = cache_position if isinstance(cache_position, int) else cache_position.shape[0] layer = self.layers[layer_idx] # Handle stochastic layers by getting the active mixer's cache @@ -794,6 +797,7 @@ def setup( hidden_size=hidden_size, num_attention_heads=num_heads, partial_rotary_factor=1.0, + rope_parameters={"rope_theta": rope_theta, "rope_type": "default"}, ) return nn.ModuleDict({"rotary_emb": MistralRotaryEmbedding(config=rotary_config)}) diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index d0e1988f1..be7b7bd6e 100644 --- 
a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -56,21 +56,38 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__(self, config: MTPLlamaConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + # Support both transformers 4.x (rope_theta + rope_scaling) and 5.x (rope_parameters) + if hasattr(config, "rope_parameters"): + self.rope_type = config.rope_parameters.get("rope_type", "default") + elif hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + if self.rope_type == "default": + self.rope_init_fn = self.compute_default_rope_parameters + else: + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq + @staticmethod + def compute_default_rope_parameters(config, device=None, seq_len=None): + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: diff --git a/setup.cfg b/setup.cfg index 955702907..0c963c1d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.57.3,<5.0.0 + transformers>=4.57.3,<6.0.0 hf-transfer>=0.1.9 datasets>=4.4.1 huggingface-hub>=0.36.0 diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index 9c31ec80f..064014bdb 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -20,12 +20,16 @@ Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention, + _gdn_fla_available, + _kda_fla_available, is_fast_path_available, ) except ImportError: Apriel2GatedDeltaNet = None Apriel2Mamba = None KimiDeltaAttention = None + _gdn_fla_available = False + _kda_fla_available = False is_fast_path_available = False HIDDEN_SIZE = 16 @@ -104,7 +108,7 @@ def _compare_mixers( "use_backup", [ pytest.param(False, marks=pytest.mark.skipif(not _fast_gdn_available, reason="FLA not available")), - True, + pytest.param(True, marks=pytest.mark.skipif(not _gdn_fla_available, reason="GDN fla kernels not available")), ], ids=["fast", "backup"], ) @@ -146,7 +150,7 @@ def test_gdn(testing_device, use_backup, monkeypatch): "use_backup", [ pytest.param(False, marks=pytest.mark.skipif(not _kda_available, reason="KDA fused kernels not available")), - True, + pytest.param(True, marks=pytest.mark.skipif(not _kda_fla_available, reason="KDA fla kernels not available")), ], ids=["fast", "backup"], ) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 
5f0f5a80f..93b468dfd 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -391,13 +391,16 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic hidden_states = output.hidden_states + (output.logits,) # Llava models doesn't return vision hidden states, so we run the vision model directly instead. if model_testing_config.model_type == "multimodal": - if hasattr(model, "vision_tower"): - vision_output = model.vision_tower( + vision_model = ( + model.model if hasattr(model, "model") and hasattr(model.model, "vision_tower") else model + ) + if hasattr(vision_model, "vision_tower"): + vision_output = vision_model.vision_tower( pixel_values=kwargs["pixel_values"], image_sizes=kwargs["image_sizes"], output_hidden_states=True, ) - adapter_output = model.multi_modal_projector(vision_output.hidden_states[-1]) + adapter_output = vision_model.multi_modal_projector(vision_output.hidden_states[-1]) hidden_states = vision_output.hidden_states + (adapter_output,) + hidden_states hidden_states_ref_ = hidden_states_ref.copy() # Adjust the vision hidden states From a98e97aa32affc55becfc4f900aa4171f7c38321 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 1 Apr 2026 21:31:14 -0400 Subject: [PATCH 2/9] Add transformers 5.x support to external models while preserving 4.x compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - apriel2/modeling_apriel2.py: add _TRANSFORMERS_V5 flag; fix _tied_weights_keys to dict format for 5.x (list for 4.x); add rope_parameters to PixtralRotaryEmbedding SimpleNamespace config - mtp_llama/modeling_mtp_llama.py: add _TRANSFORMERS_V5 flag; fix _tied_weights_keys - apriel2/conversion/llava/config.py: handle 5.x rope_parameters dict in text and vision configs alongside 4.x rope_theta - apriel2/conversion/llava/plan.py: version-conditional source weight key prefixes (5.x LlavaForConditionalGeneration adds model. prefix to submodules) - test_cache_contracts.py: update DynamicLayer.get_mask_sizes calls to pass int in 5.x (query_length) vs tensor in 4.x; update sdpa_mask signature for 5.x (q_length/q_offset) - test_convert_from_llava.py: use version-conditional embed_tokens source key - test_equivalence.py: fix get_image_features handling — 5.x returns BaseModelOutput with projected features in pooler_output (not last_hidden_state) Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 170 ------------------ .../apriel2/conversion/llava/config.py | 4 +- .../apriel2/conversion/llava/plan.py | 37 ++-- .../apriel2/modeling_apriel2.py | 12 +- .../mtp_llama/modeling_mtp_llama.py | 5 +- .../test_apriel2/test_cache_contracts.py | 119 ++++++++---- .../test_apriel2/test_convert_from_llava.py | 9 +- .../tests/test_apriel2/test_equivalence.py | 12 +- 8 files changed, 145 insertions(+), 223 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index e736cfa69..000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,170 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Commands - -### Virtual environment - -The project uses a venv at `venv/`. 
Always activate it before running any Python commands: - -```bash -source venv/bin/activate -``` - -### Installation - -```bash -# Full install with GPU support (requires CUDA) -pip install -e ".[CORE,OPTIONAL,DEV]" - -# CPU-only install (for IDE support, no GPU required) -FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE,DEV]" --no-build-isolation -``` - -### Pre-commit hooks - -```bash -pip install pre-commit -pre-commit install -``` - -Hooks run Black (line length 119), isort, autoflake, and pyupgrade automatically on commit. - -### Running tests - -Always redirect output to a file to avoid truncation, e.g. `pytest ... 2>&1 | tee /tmp/pytest_out.txt`. - -```bash -# All tests -pytest -v -n 8 tests/ - -# Single test file or function -pytest -v tests/layers/test_lm_losses.py -pytest -v tests/layers/test_lm_losses.py::test_name - -# Run extra-slow tests (disabled by default) -pytest -v -n 8 --run-extra-slow tests/ - -# Filter by model type -pytest -v -n 8 --models gpt tests/ - -# Test Triton kernels on CPU (no GPU required) -TRITON_INTERPRET=1 pytest -v tests/layers/test_lm_losses.py -``` - -The test suite sets `FAST_LLM_SKIP_TRITON_AUTOTUNE=TRUE` automatically. Tests that require distributed execution spawn child processes via `torchrun`. `TRITON_INTERPRET=1` enables the Triton interpreter so Triton kernels run on CPU — use this when developing or debugging Triton code without a GPU. - -When working with external models (`fast_llm_external_models/`), also run: - -```bash -pytest -v -n 8 fast_llm_external_models/tests/ -``` - -### CLI - -```bash -# General form -fast-llm [--config config.yaml] [key=value overrides...] - -# Validate config without running -fast-llm train gpt --config config.yaml --validate - -# Example: train GPT -fast-llm train gpt --config examples/mistral-4-node-benchmark.yaml -``` - -## Architecture - -### Configuration system (`fast_llm/config.py`) - -The core infrastructure. Every config is a frozen dataclass decorated with `@config_class()` that inherits from `Config`. Fields use `Field(default=..., desc=..., hint=FieldHint.X)` with hints that control serialization verbosity and validation: - -- `FieldHint.architecture` — defines model structure; compared across checkpoints -- `FieldHint.core` — always required explicitly -- `FieldHint.optional/performance/stability/feature/expert` — optional tuning knobs -- `FieldHint.derived` — computed from other fields, never serialized - -Dynamic dispatch (for YAML `type:` keys) uses `@config_class(dynamic_type={BaseClass: "name"})`. The registry enables subclass selection from config files. - -`RunnableConfig` (in `fast_llm/engine/config_utils/runnable.py`) extends `Config` with CLI parsing. `fast-llm train gpt` chains two levels of dynamic type dispatch: `train` selects the trainer subcommand, `gpt` selects `GPTModelConfig`. - -**Important:** Config modules (`config.py` files) must not import heavy third-party packages (torch, numpy, etc.) at the top level — only barebone dependencies — so configs can be validated without a full GPU environment. - -### Engine (`fast_llm/engine/`) - -The training engine is model-agnostic. Key components: - -- **`distributed/`** — `DistributedConfig` defines tensor/pipeline/data/sequence parallelism. `Distributed` manages NCCL process groups and knows which ranks are peers in each dimension. - -- **`base_model/`** — `BaseModel` (abstract) and `LayerBase` define the layer interface. A model is a flat list of `Layer` objects returned by `get_layers()`. 
Layers are the unit of pipeline parallelism. - -- **`multi_stage/`** — `MultiStageModel` wraps a `BaseModel` and handles: - - Splitting layers across pipeline stages - - ZeRO-1/2/3 weight/gradient/optimizer-state sharding via `FSDP` - - Tied parameter management across stages - -- **`schedule/`** — `Schedule` builds the micro-batch execution graph; `ScheduleRunner` executes it, orchestrating pipeline-parallel forward/backward passes with gradient accumulation. - -- **`optimizer/`** — AdamW implementation in `fast_llm/functional/triton/adam.py`; `Optimizer` manages `ParamGroup`s with per-group LR schedules. - -- **`training/`** — `Trainer` base class wires everything together. Subclasses (e.g., `GPTTrainer`) provide model-specific data loading. Training loop, checkpointing, evaluation, and W&B logging are handled here. - -- **`checkpoint/`** — Supports Fast-LLM distributed format, safetensors, and HuggingFace format conversion. - -### Layers (`fast_llm/layers/`) - -Reusable building blocks consumed by models: - -- `common/` — `Linear` (with tensor-parallel variants), normalization (LayerNorm, RMSNorm), PEFT (LoRA) -- `attention/` — Multi-head/grouped-query attention, RoPE embeddings -- `decoder/` — `TransformerBlock` composing attention + MLP, various MLP variants (dense, MoE) -- `language_model/` — `LanguageModelEmbedding`, `LanguageModelHead`, loss functions (CE, entropy, z-loss, DPO, GRPO) -- `ssm/` — State space model layers (Mamba) -- `vision/` — Vision encoder layers for multimodal models - -### Models (`fast_llm/models/`) - -Concrete model implementations: - -- `gpt/` — The main model family. `GPTBaseModel` assembles embedding + decoder blocks + LM head. `GPTModelConfig` registers HuggingFace checkpoint converters for Llama, Mistral, Mixtral, Qwen2, etc. `GPTTrainer` is the entry point for `fast-llm train gpt`. -- `multimodal/` — Vision-language model built on top of GPT. - -### Functional / Triton kernels (`fast_llm/functional/`) - -Low-level ops with optional Triton acceleration. Triton kernels live in `fast_llm/functional/triton/` and fall back to PyTorch when unavailable. Key kernels: fused entropy loss, z-loss, Adam, sparse linear (MoE), GRPO loss. - -`fast_llm/functional/triton/__init__.py` is the Triton entry point — it handles import errors, exposes `triton_available`/`triton_interpret` flags, and contains workarounds for Triton interpreter bugs. If a third-party Triton bug needs fixing, monkeypatch it here rather than editing `third_party/`. - -**`third_party/` is read-only.** Never edit files under `third_party/`. Fix issues by monkeypatching the relevant module attribute in Fast-LLM code (typically `fast_llm/functional/triton/__init__.py`). - -### Data (`fast_llm/data/`) - -- `dataset/` — Memmap, blended, concatenated, FIM, random, and streaming datasets -- `data/gpt/` — GPT-specific data pipeline (tokenized memmap sequences) -- `preparation/` — Offline dataset preprocessing tools -- `document/` — Document-level abstractions for variable-length inputs - -## Testing Conventions - -Tests live in `tests/`. The following patterns work well in this codebase. - -**Structure:** -- Prefer thin test bodies: construct inputs, call the function, compare outputs. Put expected-value derivation in a helper dataclass with `@cached_property` fields built up step by step. -- Return `None` from an `expected_*` property when a feature flag is off so the test body stays unconditional. 
- -**Parametrization:** -- Generate test cases as a cross-product of `base_cases × variants` via list comprehension with a `_make_name` helper and a filter clause for invalid combinations. -- Include boundary inputs (e.g., sequences shorter than a parameter, zero padding) as named base cases with explanatory comments. - -**Reference implementations:** -- Reference/ground-truth functions in tests must stay independently correct. Never change a reference to match new implementation behavior — if they disagree, suspect the new implementation first. -- Prefer plain Python loops over tensor ops in reference helpers to stay clearly independent from the implementation. - -## Code Style - -- **Imports**: Third-party → `import package.module` (keep fully qualified). First-party → `from fast_llm.module import Thing`. No relative imports. Optional/slow imports inside methods or under `if typing.TYPE_CHECKING:`. -- **Naming**: No abbreviations (use `batch_size` not `bs`). Private members get a single `_` prefix; never use `__`. Keep public interfaces lean. -- **Types**: Always type-hint public interfaces. Use modern syntax (`X | Y`, `list[T]` not `List[T]`). -- **Paths**: Use `pathlib.Path`, not `os.path`. -- **Python version**: 3.12+. diff --git a/fast_llm_external_models/apriel2/conversion/llava/config.py b/fast_llm_external_models/apriel2/conversion/llava/config.py index ac8f70dba..4d2e4d934 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/config.py +++ b/fast_llm_external_models/apriel2/conversion/llava/config.py @@ -24,7 +24,7 @@ def convert_config(llava_config: dict) -> dict: hidden_size = text_config["hidden_size"] num_heads = text_config["num_attention_heads"] num_kv_heads = text_config["num_key_value_heads"] - rope_theta = text_config["rope_theta"] + rope_theta = text_config.get("rope_theta") or text_config.get("rope_parameters", {}).get("rope_theta", 10000.0) # Use explicit head_dim if available (some models have head_dim != hidden_size // num_heads) # Note: MistralConfig.head_dim is None by default, so we must check for None explicitly head_dim = text_config.get("head_dim") @@ -98,7 +98,7 @@ def _convert_vision_config(llava_config: dict) -> dict: num_heads = vision_config["num_attention_heads"] num_layers = vision_config["num_hidden_layers"] intermediate_size = vision_config["intermediate_size"] - rope_theta = vision_config["rope_theta"] + rope_theta = vision_config.get("rope_theta") or vision_config.get("rope_parameters", {}).get("rope_theta", 10000.0) patch_size = vision_config["patch_size"] num_channels = vision_config["num_channels"] # Use explicit head_dim if available diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index a97e46c1a..e4f6d65bd 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -1,6 +1,7 @@ """Llava to Apriel2 weight conversion plan.""" from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W +from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V5 def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: @@ -14,26 +15,42 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: num_text_layers = llava_config.get("text_config", {}).get("num_hidden_layers", 0) num_vision_layers = llava_config.get("vision_config", {}).get("num_hidden_layers", 0) + # In transformers 5.x, LlavaForConditionalGeneration adds a "model." 
prefix to most submodules. + # 4.x: language_model.model.*, vision_tower.*, multi_modal_projector.* + # 5.x: model.language_model.model.*, model.vision_tower.*, model.multi_modal_projector.* + # lm_head stays as language_model.lm_head.weight in both versions. + if _TRANSFORMERS_V5: + llava_text_model = W("model", "language_model", "model") + llava_vision_tower = W("model", "vision_tower") + llava_projector = W("model", "multi_modal_projector") + else: + llava_text_model = W("language_model", "model") + llava_vision_tower = W("vision_tower") + llava_projector = W("multi_modal_projector") + # Static mappings static_mappings = [ - (W("language_model", "model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")), + (llava_text_model / "embed_tokens" / "weight", W("model", "embed_tokens", "weight")), (W("language_model", "lm_head", "weight"), W("lm_head", "weight")), - (W("language_model", "model", "norm", "weight"), W("model", "norm", "weight")), + (llava_text_model / "norm" / "weight", W("model", "norm", "weight")), ( - W("vision_tower", "patch_conv", "weight"), + llava_vision_tower / "patch_conv" / "weight", W("model", "vision_encoder", "embeddings", "patch_embeddings", "weight"), ), - (W("vision_tower", "ln_pre", "weight"), W("model", "vision_encoder", "embeddings", "normalization", "weight")), ( - W("multi_modal_projector", "linear_1", "weight"), + llava_vision_tower / "ln_pre" / "weight", + W("model", "vision_encoder", "embeddings", "normalization", "weight"), + ), + ( + llava_projector / "linear_1" / "weight", W("model", "vision_encoder", "adapter", "linear_1", "weight"), ), - (W("multi_modal_projector", "linear_1", "bias"), W("model", "vision_encoder", "adapter", "linear_1", "bias")), + (llava_projector / "linear_1" / "bias", W("model", "vision_encoder", "adapter", "linear_1", "bias")), ( - W("multi_modal_projector", "linear_2", "weight"), + llava_projector / "linear_2" / "weight", W("model", "vision_encoder", "adapter", "linear_2", "weight"), ), - (W("multi_modal_projector", "linear_2", "bias"), W("model", "vision_encoder", "adapter", "linear_2", "bias")), + (llava_projector / "linear_2" / "bias", W("model", "vision_encoder", "adapter", "linear_2", "bias")), ] for src, tgt in static_mappings: @@ -41,7 +58,7 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # Text decoder layers for layer in range(num_text_layers): - llava_layer = W("language_model", "model", "layers", layer) + llava_layer = llava_text_model / "layers" / layer apriel_layer = W("model", "decoder", "blocks", layer) # Attention projections @@ -64,7 +81,7 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # Vision encoder layers for layer in range(num_vision_layers): - llava_layer = W("vision_tower", "transformer", "layers", layer) + llava_layer = llava_vision_tower / "transformer" / "layers" / layer apriel_layer = W("model", "vision_encoder", "encoder", "blocks", layer) # Attention projections diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 6501efde6..4e3fe495e 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -1,5 +1,6 @@ """Apriel2 HuggingFace model implementation.""" +import dataclasses import math import random from types import SimpleNamespace @@ -7,6 +8,7 @@ import torch import torch.nn.functional as F +import transformers from einops import rearrange, repeat from torch import nn from transformers import GenerationMixin, 
PreTrainedModel @@ -50,6 +52,7 @@ _gdn_fla_available = chunk_gated_delta_rule is not None and rms_norm_gated is not None _kda_fla_available = chunk_kda is not None +_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) try: @@ -784,6 +787,7 @@ def setup( rope_theta=rope_theta, image_size=rotary_config_dict["max_image_size"], patch_size=rotary_config_dict["patch_size"], + rope_parameters={"rope_theta": rope_theta, "rope_type": "default"}, ) return nn.ModuleDict({"rotary_emb": PixtralRotaryEmbedding(config=rotary_config)}) @@ -2439,7 +2443,7 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} if _TRANSFORMERS_V5 else ["lm_head.weight"] def __init__(self, config: Apriel2TextConfig): super().__init__(config) @@ -3062,7 +3066,7 @@ class Apriel2ForConditionalGeneration(Apriel2PreTrainedModel, GenerationMixin): """ config_class = Apriel2Config - _tied_weights_keys = [] # No weight tying by default, but can be configured + _tied_weights_keys = {} if _TRANSFORMERS_V5 else [] # No weight tying by default, but can be configured def __init__(self, config: Apriel2Config): super().__init__(config) @@ -3072,7 +3076,9 @@ def __init__(self, config: Apriel2Config): # Handle weight tying if configured if config.tie_word_embeddings: - self._tied_weights_keys = ["lm_head.weight"] + self._tied_weights_keys = ( + {"lm_head.weight": "model.embed_tokens.weight"} if _TRANSFORMERS_V5 else ["lm_head.weight"] + ) self.post_init() diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index be7b7bd6e..bdb43d3f2 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -1,8 +1,10 @@ +import dataclasses from functools import partial from typing import Callable, Optional, Union import torch import torch.utils.checkpoint +import transformers from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache @@ -28,6 +30,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MTPLlamaConfig" +_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) class LlamaRMSNorm(nn.Module): @@ -779,7 +782,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} if _TRANSFORMERS_V5 else ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py index 337ff1fa3..6ed54a57f 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -27,7 +27,11 @@ import pytest import torch -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache, _AttentionCache +from fast_llm_external_models.apriel2.modeling_apriel2 import ( + _TRANSFORMERS_V5, + Apriel2Cache, + _AttentionCache, +) # ============================================================================= # SECTION 1: FULL ATTENTION - 
_AttentionCache vs DynamicLayer @@ -113,7 +117,9 @@ def test_hf_mask_sizes_kv_length( # Verify HF's kv_length follows the expected formula cache_position = torch.arange(1) # Single token decode - hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] assert hf_kv_len == expected_kv_len @@ -130,7 +136,9 @@ def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, hf_dynamic_layer.update(key.clone(), value.clone()) cache_position = torch.arange(1) - _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" @@ -248,7 +256,9 @@ def test_kv_offset_zero_before_window_full( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) # Verify HF returns 0 offset before window full assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}" @@ -271,7 +281,9 @@ def test_kv_offset_increases_after_window_full( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) # At window boundary, offset should be 1 assert hf_kv_offset == 1, "HF offset should be 1 at window boundary" @@ -284,7 +296,9 @@ def test_kv_offset_increases_after_window_full( hf_sliding_layer.update(key.clone(), value.clone()) apriel_sliding_cache.update(key.clone(), value.clone()) - hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) expected_offset = i + 2 assert hf_kv_offset == expected_offset @@ -306,7 +320,7 @@ def test_kv_length_capped_at_window( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position) # HF returns window (window-1 cached + 1 query) assert hf_kv_len == window_size @@ -443,8 +457,12 @@ def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): hf_layer.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) - apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position, layer_idx=0 + ) assert apr_kv_len == hf_kv_len assert apr_kv_offset == hf_kv_offset @@ -464,8 +482,12 @@ def test_get_mask_sizes_matches_sliding_layer(self, swa_config): hf_layer.update(key.clone(), value.clone()) cache_position = 
torch.arange(1) - hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) - apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + ) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes( + cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position, layer_idx=0 + ) assert apr_kv_len == hf_kv_len assert apr_kv_offset == hf_kv_offset @@ -508,13 +530,23 @@ def test_full_attention_decode_can_attend_to_all(self): kv_length = cache.cumulative_length + 1 kv_offset = 0 - mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=causal_mask_function, - ) + if _TRANSFORMERS_V5: + mask = sdpa_mask( + batch_size=1, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) + else: + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) if mask is not None: # Query at position 10 should attend to positions 0-10 @@ -540,13 +572,23 @@ def test_sliding_window_decode_respects_window(self, window_size): kv_offset = max(cumulative - window_size + 1, 0) kv_length = window_size - 1 + 1 # cached + query - mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=sliding_window_causal_mask_function(window_size), - ) + if _TRANSFORMERS_V5: + mask = sdpa_mask( + batch_size=1, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) + else: + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) if mask is not None: query_mask = mask[0, 0, 0, :] @@ -573,14 +615,25 @@ def test_prefill_has_causal_pattern(self): kv_length = cache.cumulative_length kv_offset = 0 - mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=causal_mask_function, - allow_is_causal_skip=False, # Force mask creation - ) + if _TRANSFORMERS_V5: + mask = sdpa_mask( + batch_size=1, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) + else: + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) if mask is not None: # Check causal pattern diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py index f96f5ac40..dc01a66b0 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py +++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py @@ -21,7 +21,7 @@ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config from fast_llm_external_models.apriel2.conversion 
import execute, plan_llava_to_apriel2, plan_surgery -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration +from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V5, Apriel2ForConditionalGeneration # ============================================================================= # Config Conversion Tests @@ -129,7 +129,12 @@ def test_plan_weight_values_unchanged(self, llava_pixtral_checkpoint): apriel2_weights = execute(plan, source_weights, seed=0) # Check specific weights are identical - source_embed = source_weights["language_model.model.embed_tokens.weight"] + source_embed_key = ( + "model.language_model.model.embed_tokens.weight" + if _TRANSFORMERS_V5 + else "language_model.model.embed_tokens.weight" + ) + source_embed = source_weights[source_embed_key] target_embed = apriel2_weights["model.embed_tokens.weight"] assert torch.equal(source_embed, target_embed) diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index 8734aa02c..c1a814e00 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -94,7 +94,12 @@ def create_inputs(model: LlavaForConditionalGeneration, config: InputConfig, see dummy_pixel = torch.randn(1, 3, h, w) with torch.no_grad(): features = model.get_image_features(dummy_pixel) - num_patches = features[0].shape[0] if isinstance(features, list) else features.shape[1] + if isinstance(features, list): + num_patches = features[0].shape[0] + elif hasattr(features, "pooler_output") and features.pooler_output is not None: + num_patches = features.pooler_output[0].shape[0] + else: + num_patches = features.shape[1] else: num_patches = 0 @@ -164,7 +169,10 @@ def get_pixtral_vision_features(source: LlavaForConditionalGeneration, pixel_val """Get vision features from Pixtral, flattened to [total_patches, hidden].""" features = source.get_image_features(pixel_values) if isinstance(features, list): - features = torch.cat(features, dim=0) + return torch.cat(features, dim=0) + # 5.x: BaseModelOutput with pooler_output = list of projected feature tensors + if hasattr(features, "pooler_output") and features.pooler_output is not None: + return torch.cat(features.pooler_output, dim=0) return features From 237c436fb880b1c3abf37d917d1a0885bd0ae0e6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 15:30:47 -0400 Subject: [PATCH 3/9] Fix MTP Llama converter bugs and hidden state collection - Fix num_blocks off-by-one in import_config (was subtracting 1) - Fix num_hidden_layers off-by-one in export_config (was adding 1) - Fix mtp_heads index off-by-one in get_converters (was prediction_distance - 1) - Fix hidden state collection order in MTPLlamaModel: add embedding before trunk loop and add trunk layer outputs inside the loop, consistent with standard transformers @capture_outputs behavior Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/models/gpt/conversion/mtp_llama.py | 6 +++--- .../mtp_llama/modeling_mtp_llama.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index cb9c5c1f2..f681c4a24 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -58,7 +58,7 @@ def get_converters( converters += cls.block_converter_class.get_converters( 
config.decoder.last_block_config, f"multi_token_prediction.blocks.{prediction_distance-2}", - f"model.mtp_heads.{prediction_distance - 1}", + f"model.mtp_heads.{prediction_distance - 2}", ) converters += cls.normalization_converter_class.get_converters( config.head.normalization, @@ -73,7 +73,7 @@ class MTPLlamaDecoderConverter(LlamaDecoderConverter): def import_config(cls, config: dict) -> dict: return { "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"] - 1, + "num_blocks": config["num_hidden_layers"], } @classmethod @@ -82,7 +82,7 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict: Assert.custom(isinstance, config, FixedBlockSequenceConfig) return safe_merge_dicts( cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks + 1}, + {"num_hidden_layers": config.num_blocks}, ) diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index bdb43d3f2..1b54f1c5f 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -577,11 +577,13 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - # MTP: The last layer is not part of the shared trunk - for decoder_layer in self.layers[:-1]: - if output_hidden_states: - all_hidden_states += (hidden_states,) + # MTP: The last layer is not part of the shared trunk. + # Always add the initial embedding state first, then add the output of each trunk layer. + # This is consistent with how standard transformers models collect hidden states via @capture_outputs. + if output_hidden_states: + all_hidden_states += (hidden_states,) + for decoder_layer in self.layers[:-1]: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **flash_attn_kwargs), @@ -609,6 +611,9 @@ def forward( hidden_states = layer_outputs[0] + if output_hidden_states: + all_hidden_states += (hidden_states,) + if output_attentions: all_self_attns += (layer_outputs[1],) From 35ae20e136409ab421fb26d799b81bf4d8ee6a2b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 15:50:28 -0400 Subject: [PATCH 4/9] Switch testing tokenizer from santacoder to gpt2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update TOKENIZER_NAME from "bigcode/santacoder" to "gpt2" and update all hardcoded token values in data tests to match the gpt2 vocabulary. Also fix deprecated huggingface_hub.HfFolder.get_token() → get_token(). 
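For reference, a minimal sketch of how the expected gpt2 token IDs in these tests can be regenerated. The sample string and printed values are illustrative only; the actual fixture texts live in tests/utils/dataset.py and are not reproduced here, and only the 50256 boundary token and the tokenizer call pattern come from the patch itself.

```python
# Hedged sketch: regenerate expected token IDs after the switch to the gpt2 tokenizer.
# The sample text is an assumption, not a fixture string from the test suite.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.eos_token_id)                 # 50256, the boundary token in the updated samples
print(tokenizer("hello world")["input_ids"])  # [31373, 995] if the fixture text were "hello world"
```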
Co-Authored-By: Claude Sonnet 4.6 --- .../data/preparation/gpt_memmap/prepare.py | 2 +- tests/data/test_blending.py | 26 +- tests/data/test_concatenate.py | 16 +- tests/data/test_dataset_discovery.py | 18 +- tests/data/test_fim.py | 16 +- tests/data/test_image_patch.py | 12 +- tests/data/test_loss_masking_spans.py | 12 +- tests/data/test_preference_spans.py | 14 +- tests/data/test_preparator.py | 34 +-- tests/data/test_sampling.py | 16 +- tests/data/test_slice.py | 26 +- tests/data/test_tokenizer.py | 263 +++++++++--------- tests/utils/dataset.py | 2 +- tests/utils/global_variables.py | 2 +- 14 files changed, 230 insertions(+), 229 deletions(-) diff --git a/fast_llm/data/preparation/gpt_memmap/prepare.py b/fast_llm/data/preparation/gpt_memmap/prepare.py index 70a1e13e8..71aa1199b 100644 --- a/fast_llm/data/preparation/gpt_memmap/prepare.py +++ b/fast_llm/data/preparation/gpt_memmap/prepare.py @@ -80,7 +80,7 @@ def _load_dataset(self) -> datasets.Dataset: return dataset def _get_croissant_metadata(self): - token = huggingface_hub.HfFolder.get_token() + token = huggingface_hub.get_token() try: # Retrieve the dataset metadata in croissant format url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant" diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index edbe479cc..1407397f6 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -31,25 +31,25 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, GPT_BLENDED_SAMPLES = [ - [49152, 46, 10, 819, 19, 45], - [45, 69, 17, 86, 38826, 15], - [49152, 83, 80, 20452, 45, 93], - [15, 25, 51, 31, 32348, 64], - [64, 17, 93, 78, 40, 1793], - [1793, 1, 1746, 38, 27, 58], - [93, 90, 39, 6, 75, 9], - [58, 22885, 93, 37, 92, 76], + [50256, 46, 10, 721, 19, 45], + [45, 69, 17, 86, 92, 0], + [50256, 83, 80, 29, 2, 45], + [0, 15, 25, 51, 31, 27], + [27, 0, 64, 17, 93, 78], + [78, 3955, 43, 1, 1395, 38], + [45, 93, 90, 39, 6, 75], + [38, 27, 58, 40692, 93, 37], ] GPT_BLENDED_MIXED_SAMPLES = [ - [49152, 46, 10, 819, 19, 45], + [50256, 46, 10, 721, 19, 45], [25492, 15877, 37874, 8570, 31649, 15521], - [45, 69, 17, 86, 38826, 15], + [45, 69, 17, 86, 92, 0], [3359, 20945, 33437, 32454, 42084, 45942], - [15, 25, 51, 31, 32348, 64], - [64, 17, 93, 78, 40, 1793], + [0, 15, 25, 51, 31, 27], + [27, 0, 64, 17, 93, 78], [15112, 36731, 47864, 35586, 33356, 37537], - [1793, 1, 1746, 38, 27, 58], + [78, 3955, 43, 1, 1395, 38], ] diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 6774374bb..6a232f028 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -12,14 +12,14 @@ from tests.utils.dataset import get_common_test_dataset GPT_CONCATENATED_SAMPLES = [ - [49152, 46, 10, 819, 19, 45], - [45, 69, 17, 86, 38826, 15], - [15, 25, 51, 31, 32348, 64], - [64, 17, 93, 78, 40, 1793], - [1793, 1, 1746, 38, 27, 58], - [58, 22885, 93, 37, 92, 76], - [76, 29, 19, 17365, 93, 46], - [46, 83, 17211, 1, 785, 1023], + [50256, 46, 10, 721, 19, 45], + [45, 69, 17, 86, 92, 0], + [0, 15, 25, 51, 31, 27], + [27, 0, 64, 17, 93, 78], + [78, 3955, 43, 1, 1395, 38], + [38, 27, 58, 40692, 93, 37], + [37, 92, 76, 29, 19, 29499], + [29499, 93, 46, 83, 27159, 1], ] diff --git a/tests/data/test_dataset_discovery.py b/tests/data/test_dataset_discovery.py index 0dd9c31a4..cbe635163 100644 --- a/tests/data/test_dataset_discovery.py +++ b/tests/data/test_dataset_discovery.py @@ -25,7 +25,7 @@ {"type": "memmap", "path": "dataset_0.fast_llm_dataset"}, 
{"type": "memmap", "path": "dataset_1.fast_llm_dataset"}, ], - "weights": [44883, 43910], + "weights": [47178, 46208], }, ), ( @@ -39,7 +39,7 @@ {"type": "memmap", "path": "dataset0/dataset_0.fast_llm_dataset"}, {"type": "memmap", "path": "dataset1/dataset_1.fast_llm_dataset"}, ], - "weights": [44883, 43910], + "weights": [47178, 46208], }, ), ( @@ -59,7 +59,7 @@ {"type": "memmap", "path": "dataset/dataset_1.fast_llm_dataset"}, {"type": "memmap", "path": "dataset_0.fast_llm_dataset"}, ], - "weights": [43910, 44883], + "weights": [46208, 47178], }, ), ( @@ -78,10 +78,10 @@ {"type": "memmap", "path": "dataset_0.fast_llm_dataset"}, {"type": "memmap", "path": "dataset_1.fast_llm_dataset"}, ], - "weights": [44883, 43910], + "weights": [47178, 46208], }, ], - "weights": [44883, 88793], + "weights": [47178, 93386], }, ), ( @@ -99,11 +99,11 @@ {"type": "memmap", "path": "dataset/dataset_1.fast_llm_dataset"}, {"type": "memmap", "path": "dataset/dataset_2.fast_llm_dataset"}, ], - "weights": [43910, 44883], + "weights": [46208, 47178], }, {"type": "memmap", "path": "dataset_0.fast_llm_dataset"}, ], - "weights": [88793, 44883], + "weights": [93386, 47178], }, ), ( @@ -130,12 +130,12 @@ {"type": "memmap", "path": "dataset1/dataset3/dataset_2.fast_llm_dataset"}, {"type": "memmap", "path": "dataset1/dataset_1.fast_llm_dataset"}, ], - "weights": [44883, 43910], + "weights": [47178, 46208], }, {"type": "memmap", "path": "dataset2/dataset_3.fast_llm_dataset"}, {"type": "memmap", "path": "dataset_0.fast_llm_dataset"}, ], - "weights": [88793, 43910, 44883], + "weights": [93386, 46208, 47178], }, ), ), diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 25e42fb97..884be4554 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -9,14 +9,14 @@ from tests.utils.global_variables import TOKENIZER_PATH GPT_FIM_SAMPLES = [ - [46, 10, 819, 19, 45, 88], - [45, 69, 17, 86, 38826, 15], - [86, 89, 32348, 64, 49152, 87], - [64, 17, 93, 78, 40, 1793], - [1793, 1, 1746, 38, 27, 58], - [86, 89, 37, 92, 76, 49152], - [86, 49152, 76, 29, 19, 89], - [86, 49152, 46, 83, 17211, 1], + [46, 10, 721, 19, 45, 88], + [45, 69, 17, 86, 92, 0], + [86, 89, 31, 27, 50256, 87], + [27, 0, 64, 17, 93, 78], + [78, 3955, 43, 1, 1395, 38], + [86, 89, 55, 93, 37, 50256], + [86, 50256, 37, 92, 76, 89], + [86, 89, 1, 50256, 87, 50256], ] diff --git a/tests/data/test_image_patch.py b/tests/data/test_image_patch.py index 9d613c2ec..34cb4f32f 100644 --- a/tests/data/test_image_patch.py +++ b/tests/data/test_image_patch.py @@ -14,7 +14,7 @@ from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_TEXT from tests.utils.dataset import get_test_dataset_with_image_patches -DATASET_WITH_IMAGE_PATCHES_TOKENS = [55750, 56809, 59145, 59145] +DATASET_WITH_IMAGE_PATCHES_TOKENS = [58021, 59080, 61416, 61416] DATASET_WITH_IMAGE_PATCHES_IMAGE_MD5 = { 27: [], 30: ["a2c34e404506fe664efcdb520642f260"], @@ -37,11 +37,11 @@ 87: [(17, 4), (15, 12)], } DATASET_WITH_IMAGE_PATCHES_SAMPLES = { - 27: [49152, 63, 82, 11, 27799, 49152], - 30: [49152, 31, 2327, (4, 1), 27, 1448, 62, 43, 49152], - 31: [49152, 60, 55, (2, 4), 80, 30, (3, 4), 85, 22, 18, 49152], - 77: [49152, 13736, 85, 52, 22, 46, 5, 11807, 49152], - 87: [49152, 52, (4, 1), 89, (4, 3), 75, 11, 71, 49152], + 27: [50256, 63, 82, 11, 7456, 50256], + 30: [50256, 31, 13038, (4, 1), 27, 8220, 62, 43, 50256], + 31: [50256, 60, 55, (2, 4), 80, 30, (3, 4), 85, 4790, 50256], + 77: [50256, 73, 44179, 52, 22, 46, 5, 8226, 50256], + 87: [50256, 52, (4, 1), 89, (4, 3), 75, 11, 71, 
50256], } diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index f0a35e9b8..efec8395c 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -11,13 +11,13 @@ from tests.utils.dataset import get_test_dataset_with_loss_masking_spans from tests.utils.global_variables import TOKENIZER_NAME -DATASET_WITH_SPAN_TOKENS = 45577 +DATASET_WITH_SPAN_TOKENS = 47782 DATASET_WITH_SPAN_SAMPLES = { - 27: [49152, 63, 82, 11, 27799, 49152], - 30: [49152, 31, 85, 78, 27, 1448, 62, 43, 49152], - 31: [49152, 60, 55, 80, 30, 85, 22, 18, 49152], - 77: [49152, 73, 80, 85, 52, 22, 46, 5, 88, 78, 49152], - 87: [49152, 52, 42536, 11, 71, 49152], + 27: [50256, 63, 82, 11, 7456, 50256], + 30: [50256, 31, 85, 78, 27, 8220, 62, 43, 50256], + 31: [50256, 60, 55, 80, 30, 85, 22, 18, 50256], + 77: [50256, 73, 80, 85, 52, 22, 46, 5, 88, 78, 50256], + 87: [50256, 52, 48274, 11, 71, 50256], } HF_LOSS_MASKING_SPANS = { 27: [[0, 1]], diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py index 36f8f77af..d3d46a1de 100644 --- a/tests/data/test_preference_spans.py +++ b/tests/data/test_preference_spans.py @@ -22,17 +22,17 @@ 87: ["Uz", "l", ",h"], } DATASET_WITH_PREFERENCE_SPAN_SAMPLES = { - 27: [49152, 63, 82, 11, 49152, 49152, 63, 27799, 49152], - 30: [49152, 31, 85, 78, 27, 34, 49152, 49152, 31, 85, 46, 62, 43, 49152], - 31: [49152, 60, 55, 80, 30, 85, 49152, 49152, 60, 55, 80, 30, 22, 18, 49152], - 77: [49152, 73, 80, 85, 52, 22, 46, 49152, 49152, 73, 5, 11807, 49152], - 87: [49152, 52, 89, 75, 49152, 49152, 52, 89, 11, 71, 49152], + 27: [50256, 63, 82, 11, 50256, 50256, 63, 7456, 50256], + 30: [50256, 31, 85, 78, 27, 34, 50256, 50256, 31, 85, 46, 62, 43, 50256], + 31: [50256, 60, 55, 80, 30, 85, 50256, 50256, 60, 55, 80, 30, 4790, 50256], + 77: [50256, 73, 44179, 52, 22, 46, 50256, 50256, 73, 5, 8226, 50256], + 87: [50256, 52, 89, 75, 50256, 50256, 52, 89, 11, 71, 50256], } TOKEN_PREFERENCE_SPANS = { 27: [(2, 5), (7, 9)], 30: [(3, 7), (10, 14)], - 31: [(5, 7), (12, 15)], - 77: [(2, 8), (10, 13)], + 31: [(5, 7), (12, 14)], + 77: [(2, 7), (9, 12)], 87: [(3, 5), (8, 11)], } diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 763517cde..4a149ca64 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -21,7 +21,7 @@ from tests.utils.global_variables import DATASET_CACHE, TOKENIZER_NAME COMMON_DATASET_LENGTH = 1000 -COMMON_DATASET_TOKENS = 44883 +COMMON_DATASET_TOKENS = 47178 COMMON_DATASET_TEXT = { 27: "`s,uh", 30: "@vo Tokenizer: @pytest.mark.parametrize( ("spans", "expected_token_spans", "expected_tokens"), ( - ([], [], [7196, 5297]), # No span - ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), # Simple span - ([(2, 2)], [(1, 1)], [284, 47443, 5297]), # Empty span - ([(0, 11)], [(0, 2)], [7196, 5297]), # Full span - ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), # Two spans - ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), # Overlapping spans - ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), # Nested spans - ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), # Consecutive spans - ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), # Duplicate spans - ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), # Three spans + ([], [], [31373, 995]), # No span + ([(1, 3)], [(1, 2)], [71, 417, 5439, 995]), # Simple span + ([(2, 2)], [(1, 1)], [258, 18798, 
995]), # Empty span + ([(0, 11)], [(0, 2)], [31373, 995]), # Full span + ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 695, 78, 220, 86, 1764]), # Two spans + ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 695, 78, 220, 86, 1764]), # Overlapping spans + ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 695, 78, 220, 86, 1764]), # Nested spans + ([(1, 5), (5, 7)], [(1, 2), (2, 3)], [71, 11109, 266, 1764]), # Consecutive spans + ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [258, 297, 78, 995]), # Duplicate spans + ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [258, 75, 5439, 24486, 81, 335]), # Three spans ), ) def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): @@ -79,14 +79,13 @@ def test_validate_chat_template_with_markers(common_tokenizer): ("messages", "expected_tokens", "expected_loss_masking_spans"), ( # Single turn: full assistant turn (Hello) is trainable - # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14 + # 17 tokens, loss mask spans cover 0-7 and 16 ( [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], - [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], - [(0, 7), (14, 15)], + [50256, 27, 7220, 29, 17250, 3556, 7220, 6927, 562, 10167, 29, 15496, 3556, 562, 10167, 29, 50256], + [(0, 7), (16, 17)], ), # Multi-turn: both assistant turns are fully trainable - # 27 tokens, trainable indices 7-13 and 19-25 ( [ {"role": "user", "content": "A"}, @@ -95,38 +94,41 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "D"}, ], [ - 49152, + 50256, 27, - 789, + 7220, 29, 32, - 750, - 789, - 2293, - 17822, + 3556, + 7220, + 6927, + 562, + 10167, 29, 33, - 750, - 17822, - 2293, - 789, + 3556, + 562, + 10167, + 6927, + 7220, 29, 34, - 750, - 789, - 2293, - 17822, + 3556, + 7220, + 6927, + 562, + 10167, 29, 35, - 750, - 17822, + 3556, + 562, + 10167, 29, - 49152, + 50256, ], - [(0, 7), (14, 19), (26, 27)], + [(0, 7), (16, 21), (30, 31)], ), # System + user + assistant: full assistant turn trainable - # 23 tokens, trainable indices 15-21 ( [ {"role": "system", "content": "You are helpful."}, @@ -134,41 +136,41 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "Hello"}, ], [ - 49152, + 50256, 27, - 3144, + 10057, 29, - 5815, - 1139, - 44569, - 6928, - 3144, - 2293, - 789, + 1639, + 389, + 7613, + 25970, + 10057, + 6927, + 7220, 29, - 16946, - 750, - 789, - 2293, - 17822, + 17250, + 3556, + 7220, + 6927, + 562, + 10167, 29, - 7371, - 750, - 17822, + 15496, + 3556, + 562, + 10167, 29, - 49152, + 50256, ], - [(0, 15), (22, 23)], + [(0, 15), (24, 25)], ), # User only: no trainable tokens - # 9 tokens, no trainable indices ( [{"role": "user", "content": "Hi"}], - [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], + [50256, 27, 7220, 29, 17250, 3556, 7220, 29, 50256], [(0, 9)], ), - # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) - # Trainable: indices 27-40, 49-62, 70-83 + # Long multi-turn (3 assistant responses with tags, tests span machinery) ( [ {"role": "system", "content": "You are a helpful assistant that answers questions."}, @@ -180,93 +182,92 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "The capital of Italy is Rome."}, ], [ - 49152, + 50256, 27, - 3144, + 10057, 29, - 5815, - 1139, - 373, - 44569, - 2424, - 11886, - 954, - 15737, - 14516, - 6928, - 3144, - 2293, - 789, + 1639, + 389, + 
257, + 7613, + 8796, + 326, + 7429, + 2683, + 25970, + 10057, + 6927, + 7220, 29, - 13938, - 438, - 331, - 25016, - 457, - 12409, + 2061, + 318, + 262, + 3139, + 286, + 4881, + 30, + 3556, + 7220, + 6927, 562, - 35838, - 789, - 2293, - 17822, + 10167, 29, - 2111, - 25016, - 457, - 12409, + 464, + 3139, + 286, + 4881, + 318, + 6342, + 25970, 562, - 438, - 4235, - 280, - 6928, - 17822, - 2293, - 789, + 10167, + 6927, + 7220, 29, - 13938, - 5028, - 759, - 42226, - 35838, - 789, - 2293, - 17822, - 29, - 2111, - 25016, - 457, - 759, - 42226, - 438, - 29784, + 2061, + 546, + 4486, + 30, 3556, - 6928, - 17822, - 2293, - 789, + 7220, + 6927, + 562, + 10167, 29, - 1996, - 4413, - 3326, - 35838, - 789, - 2293, - 17822, + 464, + 3139, + 286, + 4486, + 318, + 11307, + 25970, + 562, + 10167, + 6927, + 7220, 29, - 2111, - 25016, - 457, - 4413, - 3326, - 438, - 613, - 1361, - 6928, - 17822, + 1870, + 8031, + 30, + 3556, + 7220, + 6927, + 562, + 10167, + 29, + 464, + 3139, + 286, + 8031, + 318, + 10598, + 25970, + 562, + 10167, 29, - 49152, + 50256, ], - [(0, 27), (41, 49), (63, 70), (84, 85)], + [(0, 26), (40, 48), (62, 69), (83, 84)], ), ), ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index ce68e3f98..a2ea2f46e 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -18,7 +18,7 @@ def download_santacoder_tokenizer(): if not TOKENIZER_FILE.is_file(): import transformers - transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + transformers.AutoTokenizer.from_pretrained("gpt2").save_pretrained(TOKENIZER_PATH) def get_random_text( diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 20a0c7219..25de18072 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -35,7 +35,7 @@ def set_testing_global_variables(): # TODO: Fixtures TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -TOKENIZER_NAME = "bigcode/santacoder" +TOKENIZER_NAME = "gpt2" DATASET_CACHE = SHARED_RESULT_PATH / "dataset" From ccb8ce2b66d02428de31592fbf06f1b5e6e4d115 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 18:02:02 -0400 Subject: [PATCH 5/9] =?UTF-8?q?Simplify=20transformers=20v5=20compat;=20re?= =?UTF-8?q?name=20=5FTRANSFORMERS=5FV5=20=E2=86=92=20=5FTRANSFORMERS=5FV4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Deduplicate rope-type dispatch in LlamaAttentionConverter.import_config by normalizing rope_params/rope_theta from either checkpoint format first - Rename _TRANSFORMERS_V5 → _TRANSFORMERS_V4 (inverted flag) so v4 compat code is in `if _TRANSFORMERS_V4:` blocks — grep-and-delete to drop v4 - Flip all if/else so v5 code is the default path and v4 is the guarded branch - Import _TRANSFORMERS_V4 from config.py in huggingface.py; replace try/except with explicit if/else - Add comments for v5 changes that can't use the flag (TYPE_CHECKING guard, checkpoint format detection, model.model structure) Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/inference/config.py | 22 ++-- fast_llm/engine/inference/huggingface.py | 8 +- fast_llm/models/gpt/conversion/llama.py | 100 +++++++----------- fast_llm/models/gpt/huggingface.py | 3 + fast_llm/models/multimodal/huggingface.py | 3 + .../apriel2/conversion/llava/config.py | 2 + .../apriel2/conversion/llava/plan.py | 12 +-- .../apriel2/modeling_apriel2.py | 8 +- .../mtp_llama/modeling_mtp_llama.py | 4 +- 
.../test_apriel2/test_cache_contracts.py | 46 ++++---- .../test_apriel2/test_convert_from_llava.py | 8 +- tests/models/test_checkpoint.py | 1 + 12 files changed, 98 insertions(+), 119 deletions(-) diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index 7b88c9332..89dfdc9f7 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) class HuggingfaceModelConfig(transformers.PretrainedConfig): @@ -22,17 +22,15 @@ class HuggingfaceModelConfig(transformers.PretrainedConfig): fast_llm_config: FastLLMModelConfig | None = None use_cache: bool = True - if _TRANSFORMERS_V5: + def __post_init__(self, **kwargs): + # Needed for `to_diff_dict` (`__repr__`) + if self.fast_llm_config is None: + self.fast_llm_config = self.model_config_class() + super().__post_init__(**kwargs) + if self.dtype is not None: + assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch - def __post_init__(self, **kwargs): - # Needed for `to_diff_dict` (`__repr__`) - if self.fast_llm_config is None: - self.fast_llm_config = self.model_config_class() - super().__post_init__(**kwargs) - if self.dtype is not None: - assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch - - else: + if _TRANSFORMERS_V4: def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): # Needed for `to_diff_dict` (`__repr__`) @@ -103,7 +101,7 @@ def _get_config_dict( ) metadata = cls.model_config_class.load_metadata(pretrained) updates = {} - dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + dtype = kwargs.pop("dtype", None) or kwargs.pop("torch_dtype", None) # torch_dtype: transformers v4 if dtype is not None: updates[("distributed", "compute_dtype")] = dtype fast_llm_config = cls.model_config_class.from_metadata( diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 5fb7a60b5..8c6365a5f 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -11,17 +11,17 @@ from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import _TRANSFORMERS_V4, HuggingfaceModelConfig from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.utils import Assert -try: - from transformers.initialization import no_init_weights as transformers_no_init_weights -except ImportError: +if _TRANSFORMERS_V4: from transformers.modeling_utils import no_init_weights as transformers_no_init_weights +else: + from transformers.initialization import no_init_weights as transformers_no_init_weights logger = logging.getLogger(__name__) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 243e09ade..f8f36dc23 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -32,7 +32,7 @@ from 
fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div, safe_merge_dicts -_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) logger = logging.getLogger(__name__) @@ -192,68 +192,41 @@ def import_weight( class LlamaAttentionConverter: @classmethod def import_config(cls, config: dict) -> dict: - # transformers 5.x consolidates rope_theta + rope_scaling into rope_parameters - if "rope_parameters" in config: + # Normalize rope params to a single dict before dispatching on rope_type. + # transformers v5 consolidates rope_theta + rope_scaling into rope_parameters. + # transformers v4: rope_theta at top level, rope_scaling dict for non-default types. + # Note: detection is on checkpoint format, not transformers version — old checkpoints + # remain loadable with v5 transformers. + if "rope_parameters" in config: # transformers v5 rope_params = config["rope_parameters"] - rope_type = rope_params.get("rope_type", "default") - rotary_config = { - "type": rope_type, - "theta": rope_params["rope_theta"], - } - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": rope_params["factor"], - "low_frequency_factor": rope_params["low_freq_factor"], - "high_frequency_factor": rope_params["high_freq_factor"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": rope_params["attention_factor"], - "beta_fast": rope_params["beta_fast"], - "beta_slow": rope_params["beta_slow"], - "original_context_length": rope_params["original_max_position_embeddings"], - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {rope_type}") + rope_theta = rope_params["rope_theta"] + else: # transformers v4 + rope_params = config.get("rope_scaling") or {} + rope_theta = config["rope_theta"] + rope_type = rope_params.get("rope_type", "default") + rotary_config = {"type": rope_type, "theta": rope_theta} + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) else: - # transformers 4.x format: rope_theta at top level, rope_scaling separate - try: - rope_type = config["rope_scaling"]["rope_type"] - except (KeyError, TypeError): - rope_type = "default" - rotary_config = { - "type": rope_type, - "theta": config["rope_theta"], - } - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": config["rope_scaling"]["factor"], - "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], - "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": config["rope_scaling"]["attention_factor"], - "beta_fast": config["rope_scaling"]["beta_fast"], - 
"beta_slow": config["rope_scaling"]["beta_slow"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], - } - ) - else: - raise NotImplementedError(f"Unsupported rotary type: {rope_type}") + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") out = { "rotary": rotary_config, "heads": config["num_attention_heads"], @@ -304,13 +277,12 @@ def export_config(cls, config: AttentionConfig) -> dict: "attention_bias": config.add_linear_biases, "attention_dropout": config.dropout, } - if _TRANSFORMERS_V5: - return {**common, "rope_parameters": rope_parameters} - else: + if _TRANSFORMERS_V4: out = {**common, "rope_theta": rope_parameters["rope_theta"]} if type(config.rotary) is not DefaultRotaryConfig: out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} return out + return {**common, "rope_parameters": rope_parameters} @classmethod def _check_config(cls, config: AttentionConfig) -> None: diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 1fcb3fc25..b72edef7b 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -21,6 +21,9 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): model_type = "fast_llm_gpt" model_config_class = GPTModelConfig + # transformers v5: PretrainedConfig is a dataclass, so redefining a field in a subclass + # would create a new dataclass field with a different default. Guard with TYPE_CHECKING + # so type checkers see the narrowed type without affecting the runtime dataclass layout. if typing.TYPE_CHECKING: fast_llm_config: GPTModelConfig diff --git a/fast_llm/models/multimodal/huggingface.py b/fast_llm/models/multimodal/huggingface.py index 93770b446..c3e3b539c 100644 --- a/fast_llm/models/multimodal/huggingface.py +++ b/fast_llm/models/multimodal/huggingface.py @@ -23,6 +23,9 @@ class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig): model_type = "fast_llm_multi_modal" model_config_class = MultiModalModelConfig + # transformers v5: PretrainedConfig is a dataclass, so redefining a field in a subclass + # would create a new dataclass field with a different default. Guard with TYPE_CHECKING + # so type checkers see the narrowed type without affecting the runtime dataclass layout. 
if typing.TYPE_CHECKING: fast_llm_config: MultiModalModelConfig diff --git a/fast_llm_external_models/apriel2/conversion/llava/config.py b/fast_llm_external_models/apriel2/conversion/llava/config.py index 4d2e4d934..7a96d5414 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/config.py +++ b/fast_llm_external_models/apriel2/conversion/llava/config.py @@ -24,6 +24,7 @@ def convert_config(llava_config: dict) -> dict: hidden_size = text_config["hidden_size"] num_heads = text_config["num_attention_heads"] num_kv_heads = text_config["num_key_value_heads"] + # transformers v4 checkpoint: rope_theta at top level; v5: inside rope_parameters rope_theta = text_config.get("rope_theta") or text_config.get("rope_parameters", {}).get("rope_theta", 10000.0) # Use explicit head_dim if available (some models have head_dim != hidden_size // num_heads) # Note: MistralConfig.head_dim is None by default, so we must check for None explicitly @@ -98,6 +99,7 @@ def _convert_vision_config(llava_config: dict) -> dict: num_heads = vision_config["num_attention_heads"] num_layers = vision_config["num_hidden_layers"] intermediate_size = vision_config["intermediate_size"] + # transformers v4 checkpoint: rope_theta at top level; v5: inside rope_parameters rope_theta = vision_config.get("rope_theta") or vision_config.get("rope_parameters", {}).get("rope_theta", 10000.0) patch_size = vision_config["patch_size"] num_channels = vision_config["num_channels"] diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index e4f6d65bd..9303acc1e 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -1,7 +1,7 @@ """Llava to Apriel2 weight conversion plan.""" from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W -from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V5 +from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V4 def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: @@ -19,14 +19,14 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # 4.x: language_model.model.*, vision_tower.*, multi_modal_projector.* # 5.x: model.language_model.model.*, model.vision_tower.*, model.multi_modal_projector.* # lm_head stays as language_model.lm_head.weight in both versions. 
- if _TRANSFORMERS_V5: - llava_text_model = W("model", "language_model", "model") - llava_vision_tower = W("model", "vision_tower") - llava_projector = W("model", "multi_modal_projector") - else: + if _TRANSFORMERS_V4: llava_text_model = W("language_model", "model") llava_vision_tower = W("vision_tower") llava_projector = W("multi_modal_projector") + else: + llava_text_model = W("model", "language_model", "model") + llava_vision_tower = W("model", "vision_tower") + llava_projector = W("model", "multi_modal_projector") # Static mappings static_mappings = [ diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 4e3fe495e..b6e7d2a98 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -52,7 +52,7 @@ _gdn_fla_available = chunk_gated_delta_rule is not None and rms_norm_gated is not None _kda_fla_available = chunk_kda is not None -_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) try: @@ -2443,7 +2443,7 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} if _TRANSFORMERS_V5 else ["lm_head.weight"] + _tied_weights_keys = ["lm_head.weight"] if _TRANSFORMERS_V4 else {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: Apriel2TextConfig): super().__init__(config) @@ -3066,7 +3066,7 @@ class Apriel2ForConditionalGeneration(Apriel2PreTrainedModel, GenerationMixin): """ config_class = Apriel2Config - _tied_weights_keys = {} if _TRANSFORMERS_V5 else [] # No weight tying by default, but can be configured + _tied_weights_keys = [] if _TRANSFORMERS_V4 else {} # No weight tying by default, but can be configured def __init__(self, config: Apriel2Config): super().__init__(config) @@ -3077,7 +3077,7 @@ def __init__(self, config: Apriel2Config): # Handle weight tying if configured if config.tie_word_embeddings: self._tied_weights_keys = ( - {"lm_head.weight": "model.embed_tokens.weight"} if _TRANSFORMERS_V5 else ["lm_head.weight"] + ["lm_head.weight"] if _TRANSFORMERS_V4 else {"lm_head.weight": "model.embed_tokens.weight"} ) self.post_init() diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index 1b54f1c5f..40a6e4af8 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -30,7 +30,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MTPLlamaConfig" -_TRANSFORMERS_V5 = dataclasses.is_dataclass(transformers.PretrainedConfig) +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) class LlamaRMSNorm(nn.Module): @@ -787,7 +787,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} if _TRANSFORMERS_V5 else ["lm_head.weight"] + _tied_weights_keys = ["lm_head.weight"] if _TRANSFORMERS_V4 else {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py 
b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py index 6ed54a57f..d8b12c586 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -28,7 +28,7 @@ import torch from fast_llm_external_models.apriel2.modeling_apriel2 import ( - _TRANSFORMERS_V5, + _TRANSFORMERS_V4, Apriel2Cache, _AttentionCache, ) @@ -118,7 +118,7 @@ def test_hf_mask_sizes_kv_length( # Verify HF's kv_length follows the expected formula cache_position = torch.arange(1) # Single token decode hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] ) expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] assert hf_kv_len == expected_kv_len @@ -137,7 +137,7 @@ def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, cache_position = torch.arange(1) _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] ) assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" @@ -257,7 +257,7 @@ def test_kv_offset_zero_before_window_full( cache_position = torch.arange(1) hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] ) # Verify HF returns 0 offset before window full @@ -282,7 +282,7 @@ def test_kv_offset_increases_after_window_full( cache_position = torch.arange(1) hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] ) # At window boundary, offset should be 1 @@ -297,7 +297,7 @@ def test_kv_offset_increases_after_window_full( apriel_sliding_cache.update(key.clone(), value.clone()) hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] ) expected_offset = i + 2 @@ -320,7 +320,7 @@ def test_kv_length_capped_at_window( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position if _TRANSFORMERS_V4 else cache_position.shape[0]) # HF returns window (window-1 cached + 1 query) assert hf_kv_len == window_size @@ -458,10 +458,10 @@ def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): cache_position = torch.arange(1) hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] ) apr_kv_len, apr_kv_offset = cache.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position, layer_idx=0 + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0], layer_idx=0 ) assert apr_kv_len == hf_kv_len @@ -483,10 +483,10 @@ def test_get_mask_sizes_matches_sliding_layer(self, swa_config): cache_position = torch.arange(1) hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position + cache_position if 
_TRANSFORMERS_V4 else cache_position.shape[0] ) apr_kv_len, apr_kv_offset = cache.get_mask_sizes( - cache_position.shape[0] if _TRANSFORMERS_V5 else cache_position, layer_idx=0 + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0], layer_idx=0 ) assert apr_kv_len == hf_kv_len @@ -530,11 +530,10 @@ def test_full_attention_decode_can_attend_to_all(self): kv_length = cache.cumulative_length + 1 kv_offset = 0 - if _TRANSFORMERS_V5: + if _TRANSFORMERS_V4: mask = sdpa_mask( batch_size=1, - q_length=cache_position.shape[0], - q_offset=cache_position[0].item(), + cache_position=cache_position, kv_length=kv_length, kv_offset=kv_offset, mask_function=causal_mask_function, @@ -542,7 +541,8 @@ def test_full_attention_decode_can_attend_to_all(self): else: mask = sdpa_mask( batch_size=1, - cache_position=cache_position, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), kv_length=kv_length, kv_offset=kv_offset, mask_function=causal_mask_function, @@ -572,11 +572,10 @@ def test_sliding_window_decode_respects_window(self, window_size): kv_offset = max(cumulative - window_size + 1, 0) kv_length = window_size - 1 + 1 # cached + query - if _TRANSFORMERS_V5: + if _TRANSFORMERS_V4: mask = sdpa_mask( batch_size=1, - q_length=cache_position.shape[0], - q_offset=cache_position[0].item(), + cache_position=cache_position, kv_length=kv_length, kv_offset=kv_offset, mask_function=sliding_window_causal_mask_function(window_size), @@ -584,7 +583,8 @@ def test_sliding_window_decode_respects_window(self, window_size): else: mask = sdpa_mask( batch_size=1, - cache_position=cache_position, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), kv_length=kv_length, kv_offset=kv_offset, mask_function=sliding_window_causal_mask_function(window_size), @@ -615,11 +615,10 @@ def test_prefill_has_causal_pattern(self): kv_length = cache.cumulative_length kv_offset = 0 - if _TRANSFORMERS_V5: + if _TRANSFORMERS_V4: mask = sdpa_mask( batch_size=1, - q_length=cache_position.shape[0], - q_offset=cache_position[0].item(), + cache_position=cache_position, kv_length=kv_length, kv_offset=kv_offset, mask_function=causal_mask_function, @@ -628,7 +627,8 @@ def test_prefill_has_causal_pattern(self): else: mask = sdpa_mask( batch_size=1, - cache_position=cache_position, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), kv_length=kv_length, kv_offset=kv_offset, mask_function=causal_mask_function, diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py index dc01a66b0..545ba7864 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py +++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py @@ -21,7 +21,7 @@ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config from fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery -from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V5, Apriel2ForConditionalGeneration +from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V4, Apriel2ForConditionalGeneration # ============================================================================= # Config Conversion Tests @@ -130,9 +130,9 @@ def test_plan_weight_values_unchanged(self, llava_pixtral_checkpoint): # Check specific weights are 
identical source_embed_key = ( - "model.language_model.model.embed_tokens.weight" - if _TRANSFORMERS_V5 - else "language_model.model.embed_tokens.weight" + "language_model.model.embed_tokens.weight" + if _TRANSFORMERS_V4 + else "model.language_model.model.embed_tokens.weight" ) source_embed = source_weights[source_embed_key] target_embed = apriel2_weights["model.embed_tokens.weight"] diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index cb20c9ca3..7b8d62bc5 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -391,6 +391,7 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic hidden_states = output.hidden_states + (output.logits,) # Llava models doesn't return vision hidden states, so we run the vision model directly instead. if model_testing_config.model_type == "multimodal": + # transformers v5: LlavaForConditionalGeneration wraps submodules under model.* vision_model = ( model.model if hasattr(model, "model") and hasattr(model.model, "vision_tower") else model ) From c92800635a52068f0a92eb64f787ffce21eb952a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 18:13:23 -0400 Subject: [PATCH 6/9] Replace W-object path chaining with explicit W() calls in plan.py Use tuple prefixes unpacked into W(...) instead of the / operator, keeping the _TRANSFORMERS_V4 branching for the path prefix. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/models/gpt/conversion/mtp_llama.py | 6 +- .../apriel2/conversion/llava/plan.py | 74 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index f681c4a24..cb9c5c1f2 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -58,7 +58,7 @@ def get_converters( converters += cls.block_converter_class.get_converters( config.decoder.last_block_config, f"multi_token_prediction.blocks.{prediction_distance-2}", - f"model.mtp_heads.{prediction_distance - 2}", + f"model.mtp_heads.{prediction_distance - 1}", ) converters += cls.normalization_converter_class.get_converters( config.head.normalization, @@ -73,7 +73,7 @@ class MTPLlamaDecoderConverter(LlamaDecoderConverter): def import_config(cls, config: dict) -> dict: return { "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"], + "num_blocks": config["num_hidden_layers"] - 1, } @classmethod @@ -82,7 +82,7 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict: Assert.custom(isinstance, config, FixedBlockSequenceConfig) return safe_merge_dicts( cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks}, + {"num_hidden_layers": config.num_blocks + 1}, ) diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index 9303acc1e..e4f147508 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -20,37 +20,37 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # 5.x: model.language_model.model.*, model.vision_tower.*, model.multi_modal_projector.* # lm_head stays as language_model.lm_head.weight in both versions. 
if _TRANSFORMERS_V4: - llava_text_model = W("language_model", "model") - llava_vision_tower = W("vision_tower") - llava_projector = W("multi_modal_projector") + text_model_prefix = ("language_model", "model") + vision_tower_prefix = ("vision_tower",) + projector_prefix = ("multi_modal_projector",) else: - llava_text_model = W("model", "language_model", "model") - llava_vision_tower = W("model", "vision_tower") - llava_projector = W("model", "multi_modal_projector") + text_model_prefix = ("model", "language_model", "model") + vision_tower_prefix = ("model", "vision_tower") + projector_prefix = ("model", "multi_modal_projector") # Static mappings static_mappings = [ - (llava_text_model / "embed_tokens" / "weight", W("model", "embed_tokens", "weight")), + (W(*text_model_prefix, "embed_tokens", "weight"), W("model", "embed_tokens", "weight")), (W("language_model", "lm_head", "weight"), W("lm_head", "weight")), - (llava_text_model / "norm" / "weight", W("model", "norm", "weight")), + (W(*text_model_prefix, "norm", "weight"), W("model", "norm", "weight")), ( - llava_vision_tower / "patch_conv" / "weight", + W(*vision_tower_prefix, "patch_conv", "weight"), W("model", "vision_encoder", "embeddings", "patch_embeddings", "weight"), ), ( - llava_vision_tower / "ln_pre" / "weight", + W(*vision_tower_prefix, "ln_pre", "weight"), W("model", "vision_encoder", "embeddings", "normalization", "weight"), ), ( - llava_projector / "linear_1" / "weight", + W(*projector_prefix, "linear_1", "weight"), W("model", "vision_encoder", "adapter", "linear_1", "weight"), ), - (llava_projector / "linear_1" / "bias", W("model", "vision_encoder", "adapter", "linear_1", "bias")), + (W(*projector_prefix, "linear_1", "bias"), W("model", "vision_encoder", "adapter", "linear_1", "bias")), ( - llava_projector / "linear_2" / "weight", + W(*projector_prefix, "linear_2", "weight"), W("model", "vision_encoder", "adapter", "linear_2", "weight"), ), - (llava_projector / "linear_2" / "bias", W("model", "vision_encoder", "adapter", "linear_2", "bias")), + (W(*projector_prefix, "linear_2", "bias"), W("model", "vision_encoder", "adapter", "linear_2", "bias")), ] for src, tgt in static_mappings: @@ -58,47 +58,47 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # Text decoder layers for layer in range(num_text_layers): - llava_layer = llava_text_model / "layers" / layer - apriel_layer = W("model", "decoder", "blocks", layer) - # Attention projections for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: - src = llava_layer / "self_attn" / proj / "weight" - tgt = apriel_layer / "mixer" / proj / "weight" - mappings[tgt] = Ref(key=src) + mappings[W("model", "decoder", "blocks", layer, "mixer", proj, "weight")] = Ref( + key=W(*text_model_prefix, "layers", layer, "self_attn", proj, "weight") + ) # MLP projections for proj in ["gate_proj", "up_proj", "down_proj"]: - src = llava_layer / "mlp" / proj / "weight" - tgt = apriel_layer / "mlp" / proj / "weight" - mappings[tgt] = Ref(key=src) + mappings[W("model", "decoder", "blocks", layer, "mlp", proj, "weight")] = Ref( + key=W(*text_model_prefix, "layers", layer, "mlp", proj, "weight") + ) # Layer norms - mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=llava_layer / "input_layernorm" / "weight") - mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref( - key=llava_layer / "post_attention_layernorm" / "weight" + mappings[W("model", "decoder", "blocks", layer, "input_layernorm", "weight")] = Ref( + key=W(*text_model_prefix, "layers", layer, "input_layernorm", 
"weight") + ) + mappings[W("model", "decoder", "blocks", layer, "post_attention_layernorm", "weight")] = Ref( + key=W(*text_model_prefix, "layers", layer, "post_attention_layernorm", "weight") ) # Vision encoder layers for layer in range(num_vision_layers): - llava_layer = llava_vision_tower / "transformer" / "layers" / layer - apriel_layer = W("model", "vision_encoder", "encoder", "blocks", layer) - # Attention projections for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: - src = llava_layer / "attention" / proj / "weight" - tgt = apriel_layer / "mixer" / proj / "weight" - mappings[tgt] = Ref(key=src) + mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "mixer", proj, "weight")] = Ref( + key=W(*vision_tower_prefix, "transformer", "layers", layer, "attention", proj, "weight") + ) # MLP projections (llava uses feed_forward, apriel uses mlp) for proj in ["gate_proj", "up_proj", "down_proj"]: - src = llava_layer / "feed_forward" / proj / "weight" - tgt = apriel_layer / "mlp" / proj / "weight" - mappings[tgt] = Ref(key=src) + mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "mlp", proj, "weight")] = Ref( + key=W(*vision_tower_prefix, "transformer", "layers", layer, "feed_forward", proj, "weight") + ) # Layer norms (different naming) - mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=llava_layer / "attention_norm" / "weight") - mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(key=llava_layer / "ffn_norm" / "weight") + mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "input_layernorm", "weight")] = Ref( + key=W(*vision_tower_prefix, "transformer", "layers", layer, "attention_norm", "weight") + ) + mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "post_attention_layernorm", "weight")] = Ref( + key=W(*vision_tower_prefix, "transformer", "layers", layer, "ffn_norm", "weight") + ) return ExprPlan( mappings=mappings, From 7567cb871320323d20fc8be8fa650086c45e42cb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 18:15:11 -0400 Subject: [PATCH 7/9] Restore loop structure in plan.py; use prefix tuples only at layer init Keep llava_layer/apriel_layer intermediate variables (with / operator) in loops; only the layer root W() calls use *prefix unpacking. 
Co-Authored-By: Claude Sonnet 4.6 --- .../apriel2/conversion/llava/plan.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index e4f147508..3fd85d8cb 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -58,47 +58,47 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # Text decoder layers for layer in range(num_text_layers): + llava_layer = W(*text_model_prefix, "layers", layer) + apriel_layer = W("model", "decoder", "blocks", layer) + # Attention projections for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: - mappings[W("model", "decoder", "blocks", layer, "mixer", proj, "weight")] = Ref( - key=W(*text_model_prefix, "layers", layer, "self_attn", proj, "weight") - ) + src = llava_layer / "self_attn" / proj / "weight" + tgt = apriel_layer / "mixer" / proj / "weight" + mappings[tgt] = Ref(key=src) # MLP projections for proj in ["gate_proj", "up_proj", "down_proj"]: - mappings[W("model", "decoder", "blocks", layer, "mlp", proj, "weight")] = Ref( - key=W(*text_model_prefix, "layers", layer, "mlp", proj, "weight") - ) + src = llava_layer / "mlp" / proj / "weight" + tgt = apriel_layer / "mlp" / proj / "weight" + mappings[tgt] = Ref(key=src) # Layer norms - mappings[W("model", "decoder", "blocks", layer, "input_layernorm", "weight")] = Ref( - key=W(*text_model_prefix, "layers", layer, "input_layernorm", "weight") - ) - mappings[W("model", "decoder", "blocks", layer, "post_attention_layernorm", "weight")] = Ref( - key=W(*text_model_prefix, "layers", layer, "post_attention_layernorm", "weight") + mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=llava_layer / "input_layernorm" / "weight") + mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref( + key=llava_layer / "post_attention_layernorm" / "weight" ) # Vision encoder layers for layer in range(num_vision_layers): + llava_layer = W(*vision_tower_prefix, "transformer", "layers", layer) + apriel_layer = W("model", "vision_encoder", "encoder", "blocks", layer) + # Attention projections for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: - mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "mixer", proj, "weight")] = Ref( - key=W(*vision_tower_prefix, "transformer", "layers", layer, "attention", proj, "weight") - ) + src = llava_layer / "attention" / proj / "weight" + tgt = apriel_layer / "mixer" / proj / "weight" + mappings[tgt] = Ref(key=src) # MLP projections (llava uses feed_forward, apriel uses mlp) for proj in ["gate_proj", "up_proj", "down_proj"]: - mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "mlp", proj, "weight")] = Ref( - key=W(*vision_tower_prefix, "transformer", "layers", layer, "feed_forward", proj, "weight") - ) + src = llava_layer / "feed_forward" / proj / "weight" + tgt = apriel_layer / "mlp" / proj / "weight" + mappings[tgt] = Ref(key=src) # Layer norms (different naming) - mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "input_layernorm", "weight")] = Ref( - key=W(*vision_tower_prefix, "transformer", "layers", layer, "attention_norm", "weight") - ) - mappings[W("model", "vision_encoder", "encoder", "blocks", layer, "post_attention_layernorm", "weight")] = Ref( - key=W(*vision_tower_prefix, "transformer", "layers", layer, "ffn_norm", "weight") - ) + mappings[apriel_layer / "input_layernorm" / 
"weight"] = Ref(key=llava_layer / "attention_norm" / "weight") + mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(key=llava_layer / "ffn_norm" / "weight") return ExprPlan( mappings=mappings, From 3c8241e49350826f33bf0baa385e60e555a38b52 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 18:18:02 -0400 Subject: [PATCH 8/9] fix --- fast_llm/models/gpt/conversion/mtp_llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index cb9c5c1f2..f681c4a24 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -58,7 +58,7 @@ def get_converters( converters += cls.block_converter_class.get_converters( config.decoder.last_block_config, f"multi_token_prediction.blocks.{prediction_distance-2}", - f"model.mtp_heads.{prediction_distance - 1}", + f"model.mtp_heads.{prediction_distance - 2}", ) converters += cls.normalization_converter_class.get_converters( config.head.normalization, @@ -73,7 +73,7 @@ class MTPLlamaDecoderConverter(LlamaDecoderConverter): def import_config(cls, config: dict) -> dict: return { "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"] - 1, + "num_blocks": config["num_hidden_layers"], } @classmethod @@ -82,7 +82,7 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict: Assert.custom(isinstance, config, FixedBlockSequenceConfig) return safe_merge_dicts( cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks + 1}, + {"num_hidden_layers": config.num_blocks}, ) From 6e4a477aba60a9ceff292f6e7a440ea62400d057 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Apr 2026 19:54:11 -0400 Subject: [PATCH 9/9] Fix mtp llama test --- fast_llm/models/gpt/huggingface.py | 24 ++++++++++++++++++++++-- fast_llm/models/gpt/model.py | 6 +++--- tests/models/test_checkpoint.py | 7 ++++++- tests/utils/model_configs.py | 2 +- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index b72edef7b..77111e9f4 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -46,6 +46,7 @@ def inner_forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + return_all_prediction_heads: bool = False, ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: return self._inner_forward( self._get_batch(input_ids, attention_mask), @@ -57,6 +58,7 @@ def inner_forward( output_attentions, output_hidden_states, return_dict, + return_all_prediction_heads, ) def _get_batch( @@ -94,6 +96,7 @@ def _inner_forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + return_all_prediction_heads: bool = False, ) -> transformers.modeling_outputs.CausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -127,7 +130,18 @@ def _inner_forward( } # TODO: Handle MTP. 
-        logits = hidden_states.pop("head.logits")
+        logits = hidden_states.pop(f"{self.fast_llm_base_model.head.module_name}.logits")
+        if return_all_prediction_heads:
+            logits = torch.stack(
+                [logits]
+                + [
+                    hidden_states.pop(f"{head.module_name}.logits")
+                    for head in self.fast_llm_base_model.multi_token_prediction.heads
+                ],
+                dim=-2,
+            )

         output = transformers.modeling_outputs.CausalLMOutputWithPast(
             logits=logits,
@@ -156,7 +168,13 @@ def _get_input(
         output_hidden_states = (
             [self.fast_llm_base_model.embeddings.module_name + "$"]
             + [layer.module_name + "$" for layer in self.fast_llm_base_model.decoder][:-1]
-            + [self.fast_llm_base_model.head.final_norm.module_name + "$"]
+            + [
+                head.final_norm.module_name + "$"
+                for head in [
+                    self.fast_llm_base_model.head,
+                    *self.fast_llm_base_model.multi_token_prediction.heads,
+                ]
+            ]
         )

         # This needs to be set before preprocessing so it propagates to layers with namespace.
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 83abaca21..2e9b4365b 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -58,16 +58,16 @@ def preprocess_batch(
         Assert.empty(kwargs.keys() & extra_kwargs.keys())
         kwargs.update(extra_kwargs)
         if phase == PhaseType.inference:
-            kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$"))
+            kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"(?:.*\.)?logits.*$"))

         if not model_input.is_meta:
             for name, reference_model in self._reference_models.items():
                 output_hidden_states = set()
                 if name in self._head_reference_models:
-                    output_hidden_states.add(re.compile(r"head\..*logits.*$"))
+                    output_hidden_states.add(re.compile(r"(?:.*\.)?logits.*$"))
                 if name in self._decoder_reference_models:
                     # TODO: Get the actual names
-                    output_hidden_states.add(re.compile(r"decoder\.\d+\.mixer_output$"))
+                    output_hidden_states.add(re.compile(r"(?:.*\.)?decoder\.\d+\.mixer_output$"))
                 assert len(output_hidden_states) >= 1
                 reference_model_input = dataclasses.replace(
                     model_input,
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 7b8d62bc5..0b4dbafc1 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -333,6 +333,8 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic
         device=testing_device,
     )
     kwargs = {"output_hidden_states": True}
+    if is_mtp := (model_ref.fast_llm_base_model.config.head.prediction_heads > 1):
+        kwargs["return_all_prediction_heads"] = True
     if model_testing_config.model_type == "multimodal":
         kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(testing_device)
         kwargs["image_sizes"] = torch.tensor(
@@ -388,7 +390,10 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic
     if model_testing_config.model_type == "multimodal" and hasattr(model, "vision_encoder"):
         kwargs["output_vision_hidden_states"] = True
     output = model(test_input, **kwargs)
-    hidden_states = output.hidden_states + (output.logits,)
+    # Fast-LLM doesn't concatenate the head hidden states.
+    hidden_states = (
+        output.hidden_states[:-1] + output.hidden_states[-1].unbind(-2) if is_mtp else output.hidden_states
+    ) + (output.logits,)
     # Llava models doesn't return vision hidden states, so we run the vision model directly instead.
     if model_testing_config.model_type == "multimodal":
         # transformers v5: LlavaForConditionalGeneration wraps submodules under model.*
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 0f89d9323..7de1de2da 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -475,7 +475,7 @@ def update_and_add_testing_config(
     "llama",
     "mtp_llama",
     updates={
-        ("model", "base_model", "decoder", "num_blocks"): 1,
+        ("model", "base_model", "decoder", "num_blocks"): 2,
         ("model", "base_model", "head", "prediction_heads"): 2,
     },
     # Megatron doesn't support multi-token prediction.
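
With `return_all_prediction_heads=True`, the wrapper stacks the extra MTP head logits onto the usual logits along `dim=-2`, and the updated test recovers one tensor per head with `unbind(-2)`. A small self-contained sketch of that layout, with toy sizes standing in for real batch/sequence/vocabulary dimensions:

```python
import torch

# Toy sizes, placeholders only.
batch, seq, vocab, extra_heads = 2, 5, 11, 1

main_logits = torch.randn(batch, seq, vocab)
mtp_logits = [torch.randn(batch, seq, vocab) for _ in range(extra_heads)]

# What _inner_forward does when return_all_prediction_heads is set.
stacked = torch.stack([main_logits] + mtp_logits, dim=-2)
assert stacked.shape == (batch, seq, 1 + extra_heads, vocab)

# What the updated test does before comparing against Fast-LLM's per-head hidden states.
per_head = stacked.unbind(-2)
assert torch.equal(per_head[0], main_logits)
```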