diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index b0b3b33a0..89dfdc9f7 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -1,4 +1,5 @@ import copy +import dataclasses import logging import os import pathlib @@ -12,20 +13,32 @@ logger = logging.getLogger(__name__) +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) + class HuggingfaceModelConfig(transformers.PretrainedConfig): model_type = "fast_llm" model_config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig + fast_llm_config: FastLLMModelConfig | None = None + use_cache: bool = True - def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): + def __post_init__(self, **kwargs): # Needed for `to_diff_dict` (`__repr__`) - if fast_llm_config is None: - fast_llm_config = self.model_config_class() - self.fast_llm_config = fast_llm_config - self.use_cache = kwargs.pop("use_cache", True) - super().__init__(**kwargs) - if self.torch_dtype is not None: - assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch + if self.fast_llm_config is None: + self.fast_llm_config = self.model_config_class() + super().__post_init__(**kwargs) + if self.dtype is not None: + assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch + + if _TRANSFORMERS_V4: + + def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): + # Needed for `to_diff_dict` (`__repr__`) + self.fast_llm_config = fast_llm_config if fast_llm_config is not None else self.model_config_class() + self.use_cache = kwargs.pop("use_cache", True) + super().__init__(**kwargs) + if self.torch_dtype is not None: + assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None: # Hack the method to save at the right place. 
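The dataclass probe above is the version gate reused throughout this patch. A minimal, self-contained sketch of the same pattern with a toy config class rather than the real HuggingfaceModelConfig (the v5 dataclass behaviour and its __post_init__ hook are assumptions this patch relies on, not something the sketch verifies):

import dataclasses

import transformers

# transformers 4.x: PretrainedConfig is a plain class; 5.x (assumed): a dataclass,
# so subclass initialization moves from __init__ to __post_init__.
_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)


class ToyConfig(transformers.PretrainedConfig):
    extra_setting: int | None = None  # hypothetical subclass field, for illustration only

    def __post_init__(self, **kwargs):
        # Only reached on the dataclass (v5) path, where the parent is assumed to define __post_init__.
        if self.extra_setting is None:
            self.extra_setting = 0
        super().__post_init__(**kwargs)

    if _TRANSFORMERS_V4:
        # On 4.x there is no dataclass machinery, so keep a classic __init__ that
        # pops the custom kwarg before delegating to PretrainedConfig.
        def __init__(self, extra_setting: int | None = None, **kwargs):
            self.extra_setting = extra_setting if extra_setting is not None else 0
            super().__init__(**kwargs)


print(ToyConfig(extra_setting=3).extra_setting)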
@@ -88,9 +101,9 @@ def _get_config_dict( ) metadata = cls.model_config_class.load_metadata(pretrained) updates = {} - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - updates[("distributed", "compute_dtype")] = torch_dtype + dtype = kwargs.pop("dtype", None) or kwargs.pop("torch_dtype", None) # torch_dtype: transformers v4 + if dtype is not None: + updates[("distributed", "compute_dtype")] = dtype fast_llm_config = cls.model_config_class.from_metadata( pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates ) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 67be46558..8c6365a5f 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -11,13 +11,19 @@ from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import _TRANSFORMERS_V4, HuggingfaceModelConfig from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.utils import Assert +if _TRANSFORMERS_V4: + from transformers.modeling_utils import no_init_weights as transformers_no_init_weights +else: + from transformers.initialization import no_init_weights as transformers_no_init_weights + + logger = logging.getLogger(__name__) @@ -38,7 +44,7 @@ def __init__( **kwargs, ): if config is None: - config = self.config_class(fast_llm_model.config) + config = self.config_class(fast_llm_config=fast_llm_model.config) assert self.runner_class.model_class.config_class is config.model_config_class assert config.fast_llm_config is fast_llm_model.config @@ -70,7 +76,7 @@ def __init__( # Transformers needs to be able to inspect the base model. self.fast_llm_base_model = fast_llm_model.base_model - with transformers.modeling_utils.no_init_weights(): + with transformers_no_init_weights(): self.post_init() if fast_llm_model.config.multi_stage.zero_stage == 3: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 491ddde6e..f8f36dc23 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -1,7 +1,9 @@ +import dataclasses import logging import typing import torch +import transformers from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -30,6 +32,8 @@ from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div, safe_merge_dicts +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) + logger = logging.getLogger(__name__) @@ -188,32 +192,37 @@ def import_weight( class LlamaAttentionConverter: @classmethod def import_config(cls, config: dict) -> dict: - try: - rope_type = config["rope_scaling"]["rope_type"] - except (KeyError, TypeError): - rope_type = "default" - rotary_config = { - "type": rope_type, - "theta": config["rope_theta"], - } + # Normalize rope params to a single dict before dispatching on rope_type. + # transformers v5 consolidates rope_theta + rope_scaling into rope_parameters. 
+ # transformers v4: rope_theta at top level, rope_scaling dict for non-default types. + # Note: detection is on checkpoint format, not transformers version — old checkpoints + # remain loadable with v5 transformers. + if "rope_parameters" in config: # transformers v5 + rope_params = config["rope_parameters"] + rope_theta = rope_params["rope_theta"] + else: # transformers v4 + rope_params = config.get("rope_scaling") or {} + rope_theta = config["rope_theta"] + rope_type = rope_params.get("rope_type", "default") + rotary_config = {"type": rope_type, "theta": rope_theta} if rope_type == "default": pass elif rope_type == "llama3": rotary_config.update( { - "scale_factor": config["rope_scaling"]["factor"], - "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], - "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], } ) elif rope_type == "yarn": rotary_config.update( { - "attention_factor": config["rope_scaling"]["attention_factor"], - "beta_fast": config["rope_scaling"]["beta_fast"], - "beta_slow": config["rope_scaling"]["beta_slow"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], } ) else: @@ -235,36 +244,45 @@ def import_config(cls, config: dict) -> dict: def export_config(cls, config: AttentionConfig) -> dict: cls._check_config(config) Assert.eq(config.softmax_scale_power, 0.5) - out = { - "num_attention_heads": config.heads, - "num_key_value_heads": config.head_groups, - "head_dim": config.head_size, - "attention_bias": config.add_linear_biases, - "attention_dropout": config.dropout, - "rope_theta": config.rotary.theta, - } + rope_parameters = {"rope_theta": config.rotary.theta} if type(config.rotary) is DefaultRotaryConfig: - pass + rope_parameters["rope_type"] = "default" elif type(config.rotary) is Llama3RotaryConfig: - out["rope_scaling"] = { - "rope_type": "llama3", - "factor": config.rotary.scale_factor, - "low_freq_factor": config.rotary.low_frequency_factor, - "high_freq_factor": config.rotary.high_frequency_factor, - "original_max_position_embeddings": config.rotary.original_context_length, - } + rope_parameters.update( + { + "rope_type": "llama3", + "factor": config.rotary.scale_factor, + "low_freq_factor": config.rotary.low_frequency_factor, + "high_freq_factor": config.rotary.high_frequency_factor, + "original_max_position_embeddings": config.rotary.original_context_length, + } + ) elif type(config.rotary) is YarnRotaryConfig: - out["rope_scaling"] = { - "rope_type": "yarn", - "attention_factor": config.rotary.attention_factor, - "beta_fast": config.rotary.beta_fast, - "beta_slow": config.rotary.beta_slow, - "original_max_position_embeddings": config.rotary.original_context_length, - } + rope_parameters.update( + { + "rope_type": "yarn", + "attention_factor": config.rotary.attention_factor, + "beta_fast": config.rotary.beta_fast, + "beta_slow": config.rotary.beta_slow, + "original_max_position_embeddings": config.rotary.original_context_length, + } + ) else: raise 
NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - return out + common = { + "num_attention_heads": config.heads, + "num_key_value_heads": config.head_groups, + "head_dim": config.head_size, + "attention_bias": config.add_linear_biases, + "attention_dropout": config.dropout, + } + if _TRANSFORMERS_V4: + out = {**common, "rope_theta": rope_parameters["rope_theta"]} + if type(config.rotary) is not DefaultRotaryConfig: + out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"} + return out + return {**common, "rope_parameters": rope_parameters} @classmethod def _check_config(cls, config: AttentionConfig) -> None: diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index cb9c5c1f2..f681c4a24 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -58,7 +58,7 @@ def get_converters( converters += cls.block_converter_class.get_converters( config.decoder.last_block_config, f"multi_token_prediction.blocks.{prediction_distance-2}", - f"model.mtp_heads.{prediction_distance - 1}", + f"model.mtp_heads.{prediction_distance - 2}", ) converters += cls.normalization_converter_class.get_converters( config.head.normalization, @@ -73,7 +73,7 @@ class MTPLlamaDecoderConverter(LlamaDecoderConverter): def import_config(cls, config: dict) -> dict: return { "block": cls.block_converter_class.import_config(config), - "num_blocks": config["num_hidden_layers"] - 1, + "num_blocks": config["num_hidden_layers"], } @classmethod @@ -82,7 +82,7 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict: Assert.custom(isinstance, config, FixedBlockSequenceConfig) return safe_merge_dicts( cls.block_converter_class.export_config(config.block), - {"num_hidden_layers": config.num_blocks + 1}, + {"num_hidden_layers": config.num_blocks}, ) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 55c30c7ee..77111e9f4 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -20,7 +20,12 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): model_type = "fast_llm_gpt" model_config_class = GPTModelConfig - fast_llm_config: GPTModelConfig + + # transformers v5: PretrainedConfig is a dataclass, so redefining a field in a subclass + # would create a new dataclass field with a different default. Guard with TYPE_CHECKING + # so type checkers see the narrowed type without affecting the runtime dataclass layout. 
+ if typing.TYPE_CHECKING: + fast_llm_config: GPTModelConfig class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): @@ -41,6 +46,7 @@ def inner_forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + return_all_prediction_heads: bool = False, ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: return self._inner_forward( self._get_batch(input_ids, attention_mask), @@ -52,6 +58,7 @@ output_attentions, output_hidden_states, return_dict, + return_all_prediction_heads, ) def _get_batch( @@ -89,6 +96,7 @@ def _inner_forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + return_all_prediction_heads: bool = False, ) -> transformers.modeling_outputs.CausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -122,7 +130,17 @@ def _inner_forward( } # TODO: Handle MTP. - logits = hidden_states.pop("head.logits") + + logits = hidden_states.pop(f"{self.fast_llm_base_model.head.module_name}.logits") + if return_all_prediction_heads: + logits = torch.stack( + [logits] + + [ + hidden_states.pop(f"{head.module_name}.logits") + for head in self.fast_llm_base_model.multi_token_prediction.heads + ], + dim=-2, + ) output = transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, @@ -151,7 +170,13 @@ def _get_input( output_hidden_states = ( [self.fast_llm_base_model.embeddings.module_name + "$"] + [layer.module_name + "$" for layer in self.fast_llm_base_model.decoder][:-1] - + [self.fast_llm_base_model.head.final_norm.module_name + "$"] + + [ + head.final_norm.module_name + "$" + for head in [ + self.fast_llm_base_model.head, + *self.fast_llm_base_model.multi_token_prediction.heads, + ] + ] ) # This needs to be set before preprocessing so it propagates to layers with namespace.
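The typing.TYPE_CHECKING guard above deserves a generic illustration. With toy classes standing in for the real configs, redeclaring the attribute only for the type checker narrows its static type without touching whatever field machinery the base class sets up at runtime:

import typing


class Base:
    payload: object = None


class Derived(Base):
    if typing.TYPE_CHECKING:
        # Seen by mypy/pyright only: Derived.payload is an int here. At runtime no new
        # class attribute (or dataclass field, if Base were a dataclass) is created.
        payload: int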
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 83abaca21..2e9b4365b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -58,16 +58,16 @@ def preprocess_batch( Assert.empty(kwargs.keys() & extra_kwargs.keys()) kwargs.update(extra_kwargs) if phase == PhaseType.inference: - kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$")) + kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"(?:.*\.)?logits.*$")) if not model_input.is_meta: for name, reference_model in self._reference_models.items(): output_hidden_states = set() if name in self._head_reference_models: - output_hidden_states.add(re.compile(r"head\..*logits.*$")) + output_hidden_states.add(re.compile(r"(?:.*\.)?logits.*$")) if name in self._decoder_reference_models: # TODO: Get the actual names - output_hidden_states.add(re.compile(r"decoder\.\d+\.mixer_output$")) + output_hidden_states.add(re.compile(r"(?:.*\.)?decoder\.\d+\.mixer_output$")) assert len(output_hidden_states) >= 1 reference_model_input = dataclasses.replace( model_input, diff --git a/fast_llm/models/multimodal/huggingface.py b/fast_llm/models/multimodal/huggingface.py index 8bf14d715..c3e3b539c 100644 --- a/fast_llm/models/multimodal/huggingface.py +++ b/fast_llm/models/multimodal/huggingface.py @@ -22,7 +22,12 @@ class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig): model_type = "fast_llm_multi_modal" model_config_class = MultiModalModelConfig - fast_llm_config: MultiModalModelConfig + + # transformers v5: PretrainedConfig is a dataclass, so redefining a field in a subclass + # would create a new dataclass field with a different default. Guard with TYPE_CHECKING + # so type checkers see the narrowed type without affecting the runtime dataclass layout. 
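The broadened pattern above accepts logits tensors from any prediction head, not only the primary one. A quick check against a few illustrative module names (the exact MTP head names are assumptions based on the head.module_name usage earlier in this patch):

import re

pattern = re.compile(r"(?:.*\.)?logits.*$")
for name in ("head.logits", "multi_token_prediction.heads.0.logits", "logits", "decoder.3.mixer_output"):
    print(f"{name}: {bool(pattern.search(name))}")
# The first three match; the mixer output does not.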
+ if typing.TYPE_CHECKING: + fast_llm_config: MultiModalModelConfig class HuggingfaceMultiModalModelForCausalLM(HuggingfaceGPTModelForCausalLM): diff --git a/fast_llm_external_models/apriel2/conversion/llava/config.py b/fast_llm_external_models/apriel2/conversion/llava/config.py index ac8f70dba..7a96d5414 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/config.py +++ b/fast_llm_external_models/apriel2/conversion/llava/config.py @@ -24,7 +24,8 @@ def convert_config(llava_config: dict) -> dict: hidden_size = text_config["hidden_size"] num_heads = text_config["num_attention_heads"] num_kv_heads = text_config["num_key_value_heads"] - rope_theta = text_config["rope_theta"] + # transformers v4 checkpoint: rope_theta at top level; v5: inside rope_parameters + rope_theta = text_config.get("rope_theta") or text_config.get("rope_parameters", {}).get("rope_theta", 10000.0) # Use explicit head_dim if available (some models have head_dim != hidden_size // num_heads) # Note: MistralConfig.head_dim is None by default, so we must check for None explicitly head_dim = text_config.get("head_dim") @@ -98,7 +99,8 @@ def _convert_vision_config(llava_config: dict) -> dict: num_heads = vision_config["num_attention_heads"] num_layers = vision_config["num_hidden_layers"] intermediate_size = vision_config["intermediate_size"] - rope_theta = vision_config["rope_theta"] + # transformers v4 checkpoint: rope_theta at top level; v5: inside rope_parameters + rope_theta = vision_config.get("rope_theta") or vision_config.get("rope_parameters", {}).get("rope_theta", 10000.0) patch_size = vision_config["patch_size"] num_channels = vision_config["num_channels"] # Use explicit head_dim if available diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index a97e46c1a..3fd85d8cb 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -1,6 +1,7 @@ """Llava to Apriel2 weight conversion plan.""" from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W +from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V4 def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: @@ -14,26 +15,42 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: num_text_layers = llava_config.get("text_config", {}).get("num_hidden_layers", 0) num_vision_layers = llava_config.get("vision_config", {}).get("num_hidden_layers", 0) + # In transformers 5.x, LlavaForConditionalGeneration adds a "model." prefix to most submodules. + # 4.x: language_model.model.*, vision_tower.*, multi_modal_projector.* + # 5.x: model.language_model.model.*, model.vision_tower.*, model.multi_modal_projector.* + # lm_head stays as language_model.lm_head.weight in both versions. 
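For the weight-name prefixes just described, a plain-string version of the mapping may help; the real plan builds names through the W(...) expression helper, so the join here is only illustrative:

# v4 layout vs the prefixed v5 layout described above.
V4_TEXT_PREFIX = ("language_model", "model")
V5_TEXT_PREFIX = ("model", "language_model", "model")


def weight_name(prefix: tuple[str, ...], *parts: str) -> str:
    return ".".join((*prefix, *parts))


print(weight_name(V4_TEXT_PREFIX, "embed_tokens", "weight"))  # language_model.model.embed_tokens.weight
print(weight_name(V5_TEXT_PREFIX, "embed_tokens", "weight"))  # model.language_model.model.embed_tokens.weight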
+ if _TRANSFORMERS_V4: + text_model_prefix = ("language_model", "model") + vision_tower_prefix = ("vision_tower",) + projector_prefix = ("multi_modal_projector",) + else: + text_model_prefix = ("model", "language_model", "model") + vision_tower_prefix = ("model", "vision_tower") + projector_prefix = ("model", "multi_modal_projector") + # Static mappings static_mappings = [ - (W("language_model", "model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")), + (W(*text_model_prefix, "embed_tokens", "weight"), W("model", "embed_tokens", "weight")), (W("language_model", "lm_head", "weight"), W("lm_head", "weight")), - (W("language_model", "model", "norm", "weight"), W("model", "norm", "weight")), + (W(*text_model_prefix, "norm", "weight"), W("model", "norm", "weight")), ( - W("vision_tower", "patch_conv", "weight"), + W(*vision_tower_prefix, "patch_conv", "weight"), W("model", "vision_encoder", "embeddings", "patch_embeddings", "weight"), ), - (W("vision_tower", "ln_pre", "weight"), W("model", "vision_encoder", "embeddings", "normalization", "weight")), ( - W("multi_modal_projector", "linear_1", "weight"), + W(*vision_tower_prefix, "ln_pre", "weight"), + W("model", "vision_encoder", "embeddings", "normalization", "weight"), + ), + ( + W(*projector_prefix, "linear_1", "weight"), W("model", "vision_encoder", "adapter", "linear_1", "weight"), ), - (W("multi_modal_projector", "linear_1", "bias"), W("model", "vision_encoder", "adapter", "linear_1", "bias")), + (W(*projector_prefix, "linear_1", "bias"), W("model", "vision_encoder", "adapter", "linear_1", "bias")), ( - W("multi_modal_projector", "linear_2", "weight"), + W(*projector_prefix, "linear_2", "weight"), W("model", "vision_encoder", "adapter", "linear_2", "weight"), ), - (W("multi_modal_projector", "linear_2", "bias"), W("model", "vision_encoder", "adapter", "linear_2", "bias")), + (W(*projector_prefix, "linear_2", "bias"), W("model", "vision_encoder", "adapter", "linear_2", "bias")), ] for src, tgt in static_mappings: @@ -41,7 +58,7 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # Text decoder layers for layer in range(num_text_layers): - llava_layer = W("language_model", "model", "layers", layer) + llava_layer = W(*text_model_prefix, "layers", layer) apriel_layer = W("model", "decoder", "blocks", layer) # Attention projections @@ -64,7 +81,7 @@ def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: # Vision encoder layers for layer in range(num_vision_layers): - llava_layer = W("vision_tower", "transformer", "layers", layer) + llava_layer = W(*vision_tower_prefix, "transformer", "layers", layer) apriel_layer = W("model", "vision_encoder", "encoder", "blocks", layer) # Attention projections diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 9e82dfc4f..b6e7d2a98 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -1,5 +1,6 @@ """Apriel2 HuggingFace model implementation.""" +import dataclasses import math import random from types import SimpleNamespace @@ -7,6 +8,7 @@ import torch import torch.nn.functional as F +import transformers from einops import rearrange, repeat from torch import nn from transformers import GenerationMixin, PreTrainedModel @@ -48,6 +50,10 @@ fused_recurrent_kda = None fused_kda_gate = None +_gdn_fla_available = chunk_gated_delta_rule is not None and rms_norm_gated is not None +_kda_fla_available = chunk_kda is not None +_TRANSFORMERS_V4 = 
not dataclasses.is_dataclass(transformers.PretrainedConfig) + try: from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn @@ -289,7 +295,7 @@ def get_mask_sizes(self, cache_position, layer_idx): For SSM/linear layers: kv_offset = 0, kv_length = query_length (no KV cache to attend to) """ - query_length = cache_position.shape[0] + query_length = cache_position if isinstance(cache_position, int) else cache_position.shape[0] layer = self.layers[layer_idx] # Handle stochastic layers by getting the active mixer's cache @@ -781,6 +787,7 @@ def setup( rope_theta=rope_theta, image_size=rotary_config_dict["max_image_size"], patch_size=rotary_config_dict["patch_size"], + rope_parameters={"rope_theta": rope_theta, "rope_type": "default"}, ) return nn.ModuleDict({"rotary_emb": PixtralRotaryEmbedding(config=rotary_config)}) @@ -794,6 +801,7 @@ def setup( hidden_size=hidden_size, num_attention_heads=num_heads, partial_rotary_factor=1.0, + rope_parameters={"rope_theta": rope_theta, "rope_type": "default"}, ) return nn.ModuleDict({"rotary_emb": MistralRotaryEmbedding(config=rotary_config)}) @@ -2435,7 +2443,7 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = ["lm_head.weight"] if _TRANSFORMERS_V4 else {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: Apriel2TextConfig): super().__init__(config) @@ -3058,7 +3066,7 @@ class Apriel2ForConditionalGeneration(Apriel2PreTrainedModel, GenerationMixin): """ config_class = Apriel2Config - _tied_weights_keys = [] # No weight tying by default, but can be configured + _tied_weights_keys = [] if _TRANSFORMERS_V4 else {} # No weight tying by default, but can be configured def __init__(self, config: Apriel2Config): super().__init__(config) @@ -3068,7 +3076,9 @@ def __init__(self, config: Apriel2Config): # Handle weight tying if configured if config.tie_word_embeddings: - self._tied_weights_keys = ["lm_head.weight"] + self._tied_weights_keys = ( + ["lm_head.weight"] if _TRANSFORMERS_V4 else {"lm_head.weight": "model.embed_tokens.weight"} + ) self.post_init() diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index d0e1988f1..40a6e4af8 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -1,8 +1,10 @@ +import dataclasses from functools import partial from typing import Callable, Optional, Union import torch import torch.utils.checkpoint +import transformers from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache @@ -28,6 +30,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MTPLlamaConfig" +_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) class LlamaRMSNorm(nn.Module): @@ -56,21 +59,38 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__(self, config: MTPLlamaConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + # Support both transformers 4.x (rope_theta + rope_scaling) and 5.x (rope_parameters) + if hasattr(config, "rope_parameters"): + self.rope_type = 
config.rope_parameters.get("rope_type", "default") + elif hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + if self.rope_type == "default": + self.rope_init_fn = self.compute_default_rope_parameters + else: + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq + @staticmethod + def compute_default_rope_parameters(config, device=None, seq_len=None): + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: @@ -557,11 +577,13 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - # MTP: The last layer is not part of the shared trunk - for decoder_layer in self.layers[:-1]: - if output_hidden_states: - all_hidden_states += (hidden_states,) + # MTP: The last layer is not part of the shared trunk. + # Always add the initial embedding state first, then add the output of each trunk layer. + # This is consistent with how standard transformers models collect hidden states via @capture_outputs. 
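The reordering described above pins down a simple invariant: the first entry of all_hidden_states is the embedding output, each trunk layer then appends its own output, so the tuple has one more entry than the number of trunk layers. A toy trace of that order (string-valued stand-in layers, not the real decoder):

hidden = "embeddings"
trunk = [lambda h, i=i: f"{h} -> layer{i}" for i in range(2)]

all_hidden_states = (hidden,)          # initial embedding state first
for layer in trunk:
    hidden = layer(hidden)
    all_hidden_states += (hidden,)     # then the output of each trunk layer
print(all_hidden_states)
# ('embeddings', 'embeddings -> layer0', 'embeddings -> layer0 -> layer1')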
+ if output_hidden_states: + all_hidden_states += (hidden_states,) + for decoder_layer in self.layers[:-1]: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **flash_attn_kwargs), @@ -589,6 +611,9 @@ def forward( hidden_states = layer_outputs[0] + if output_hidden_states: + all_hidden_states += (hidden_states,) + if output_attentions: all_self_attns += (layer_outputs[1],) @@ -762,7 +787,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = ["lm_head.weight"] if _TRANSFORMERS_V4 else {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py index 337ff1fa3..d8b12c586 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -27,7 +27,11 @@ import pytest import torch -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache, _AttentionCache +from fast_llm_external_models.apriel2.modeling_apriel2 import ( + _TRANSFORMERS_V4, + Apriel2Cache, + _AttentionCache, +) # ============================================================================= # SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer @@ -113,7 +117,9 @@ def test_hf_mask_sizes_kv_length( # Verify HF's kv_length follows the expected formula cache_position = torch.arange(1) # Single token decode - hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] assert hf_kv_len == expected_kv_len @@ -130,7 +136,9 @@ def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, hf_dynamic_layer.update(key.clone(), value.clone()) cache_position = torch.arange(1) - _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" @@ -248,7 +256,9 @@ def test_kv_offset_zero_before_window_full( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) # Verify HF returns 0 offset before window full assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}" @@ -271,7 +281,9 @@ def test_kv_offset_increases_after_window_full( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) # At window boundary, offset should be 1 assert hf_kv_offset == 1, "HF offset should be 1 at window boundary" @@ -284,7 +296,9 @@ def test_kv_offset_increases_after_window_full( 
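The get_mask_sizes change earlier, and the cache-contract tests that follow, come down to one convention difference: v4 passes the cache_position tensor, while v5 is assumed to pass the query length as a plain int. A helper mirroring the patched branch:

import torch


def query_length(cache_position) -> int:
    # Accept either convention, as the patched get_mask_sizes does.
    return cache_position if isinstance(cache_position, int) else cache_position.shape[0]


print(query_length(torch.arange(1)))  # v4-style tensor argument -> 1
print(query_length(1))                # v5-style int argument -> 1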
hf_sliding_layer.update(key.clone(), value.clone()) apriel_sliding_cache.update(key.clone(), value.clone()) - hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) expected_offset = i + 2 assert hf_kv_offset == expected_offset @@ -306,7 +320,7 @@ def test_kv_length_capped_at_window( apriel_sliding_cache.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position if _TRANSFORMERS_V4 else cache_position.shape[0]) # HF returns window (window-1 cached + 1 query) assert hf_kv_len == window_size @@ -443,8 +457,12 @@ def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): hf_layer.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) - apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0], layer_idx=0 + ) assert apr_kv_len == hf_kv_len assert apr_kv_offset == hf_kv_offset @@ -464,8 +482,12 @@ def test_get_mask_sizes_matches_sliding_layer(self, swa_config): hf_layer.update(key.clone(), value.clone()) cache_position = torch.arange(1) - hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) - apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0] + ) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes( + cache_position if _TRANSFORMERS_V4 else cache_position.shape[0], layer_idx=0 + ) assert apr_kv_len == hf_kv_len assert apr_kv_offset == hf_kv_offset @@ -508,13 +530,23 @@ def test_full_attention_decode_can_attend_to_all(self): kv_length = cache.cumulative_length + 1 kv_offset = 0 - mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=causal_mask_function, - ) + if _TRANSFORMERS_V4: + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) + else: + mask = sdpa_mask( + batch_size=1, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) if mask is not None: # Query at position 10 should attend to positions 0-10 @@ -540,13 +572,23 @@ def test_sliding_window_decode_respects_window(self, window_size): kv_offset = max(cumulative - window_size + 1, 0) kv_length = window_size - 1 + 1 # cached + query - mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=sliding_window_causal_mask_function(window_size), - ) + if _TRANSFORMERS_V4: + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) + else: + mask = sdpa_mask( + batch_size=1, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), + kv_length=kv_length, + 
kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) if mask is not None: query_mask = mask[0, 0, 0, :] @@ -573,14 +615,25 @@ def test_prefill_has_causal_pattern(self): kv_length = cache.cumulative_length kv_offset = 0 - mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=causal_mask_function, - allow_is_causal_skip=False, # Force mask creation - ) + if _TRANSFORMERS_V4: + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) + else: + mask = sdpa_mask( + batch_size=1, + q_length=cache_position.shape[0], + q_offset=cache_position[0].item(), + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) if mask is not None: # Check causal pattern diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py index f96f5ac40..545ba7864 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py +++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py @@ -21,7 +21,7 @@ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config from fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration +from fast_llm_external_models.apriel2.modeling_apriel2 import _TRANSFORMERS_V4, Apriel2ForConditionalGeneration # ============================================================================= # Config Conversion Tests @@ -129,7 +129,12 @@ def test_plan_weight_values_unchanged(self, llava_pixtral_checkpoint): apriel2_weights = execute(plan, source_weights, seed=0) # Check specific weights are identical - source_embed = source_weights["language_model.model.embed_tokens.weight"] + source_embed_key = ( + "language_model.model.embed_tokens.weight" + if _TRANSFORMERS_V4 + else "model.language_model.model.embed_tokens.weight" + ) + source_embed = source_weights[source_embed_key] target_embed = apriel2_weights["model.embed_tokens.weight"] assert torch.equal(source_embed, target_embed) diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index 8734aa02c..c1a814e00 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -94,7 +94,12 @@ def create_inputs(model: LlavaForConditionalGeneration, config: InputConfig, see dummy_pixel = torch.randn(1, 3, h, w) with torch.no_grad(): features = model.get_image_features(dummy_pixel) - num_patches = features[0].shape[0] if isinstance(features, list) else features.shape[1] + if isinstance(features, list): + num_patches = features[0].shape[0] + elif hasattr(features, "pooler_output") and features.pooler_output is not None: + num_patches = features.pooler_output[0].shape[0] + else: + num_patches = features.shape[1] else: num_patches = 0 @@ -164,7 +169,10 @@ def get_pixtral_vision_features(source: LlavaForConditionalGeneration, pixel_val """Get vision features from Pixtral, flattened 
to [total_patches, hidden].""" features = source.get_image_features(pixel_values) if isinstance(features, list): - features = torch.cat(features, dim=0) + return torch.cat(features, dim=0) + # 5.x: BaseModelOutput with pooler_output = list of projected feature tensors + if hasattr(features, "pooler_output") and features.pooler_output is not None: + return torch.cat(features.pooler_output, dim=0) return features diff --git a/setup.cfg b/setup.cfg index e035cc0c1..6748b3a6b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.57.3,<5.0.0 + transformers>=4.57.3,<6.0.0 hf-transfer>=0.1.9 datasets>=4.4.1 huggingface-hub>=0.36.0 diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index fba0b4265..ea7099ba1 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -10,24 +10,16 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig -from fast_llm.layers.ssm.gdn import _fast_gdn_available -from fast_llm.layers.ssm.kda import _kda_available from fast_llm.utils import Assert +from fast_llm_external_models.apriel2.modeling_apriel2 import ( + Apriel2GatedDeltaNet, + Apriel2Mamba, + KimiDeltaAttention, + _gdn_fla_available, + _kda_fla_available, +) from tests.utils.utils import get_stage -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import ( - Apriel2GatedDeltaNet, - Apriel2Mamba, - KimiDeltaAttention, - is_fast_path_available, - ) -except ImportError: - Apriel2GatedDeltaNet = None - Apriel2Mamba = None - KimiDeltaAttention = None - is_fast_path_available = False - HIDDEN_SIZE = 16 SEQUENCE_LENGTH = 65 BATCH_SIZE = 2 @@ -98,16 +90,8 @@ def _compare_mixers( @pytest.mark.slow -# Arguments ('seq_idx',) not implemented for torch implementation of 1d convolution. 
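The repeated if/else around sdpa_mask in the cache-contract tests above could be folded into one call site. A hedged sketch, assuming the masking utilities come from transformers.masking_utils as those tests use them, and that q_length/q_offset are the v5 keyword names shown in the branches:

import dataclasses

import torch
import transformers
from transformers.masking_utils import causal_mask_function, sdpa_mask

_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)


def build_sdpa_mask(cache_position, kv_length, kv_offset, mask_function, **extra):
    if _TRANSFORMERS_V4:
        return sdpa_mask(
            batch_size=1,
            cache_position=cache_position,
            kv_length=kv_length,
            kv_offset=kv_offset,
            mask_function=mask_function,
            **extra,
        )
    return sdpa_mask(
        batch_size=1,
        q_length=cache_position.shape[0],
        q_offset=cache_position[0].item(),
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_function,
        **extra,
    )


# May be None when the kernel can rely on is_causal directly, as the tests above note.
mask = build_sdpa_mask(torch.arange(1) + 10, kv_length=11, kv_offset=0, mask_function=causal_mask_function)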
-@pytest.mark.skipif(not is_fast_path_available, reason="GDN deps missing") -@pytest.mark.parametrize( - "use_backup", - [ - pytest.param(False, marks=pytest.mark.skipif(not _fast_gdn_available, reason="FLA not available")), - True, - ], - ids=["fast", "backup"], -) +@pytest.mark.skipif(not _gdn_fla_available, reason="GDN FLA kernels not available") +@pytest.mark.parametrize("use_backup", [False, True], ids=["fast", "backup"]) def test_gdn(testing_device, use_backup, monkeypatch): if use_backup: import fast_llm.layers.ssm.gdn as gdn_module @@ -141,15 +125,8 @@ def test_gdn(testing_device, use_backup, monkeypatch): @pytest.mark.slow -@pytest.mark.skipif(KimiDeltaAttention is None, reason="KDA external model not available") -@pytest.mark.parametrize( - "use_backup", - [ - pytest.param(False, marks=pytest.mark.skipif(not _kda_available, reason="KDA fused kernels not available")), - pytest.param(True, marks=pytest.mark.skipif(not _kda_available, reason="KDA fla package not available")), - ], - ids=["fast", "backup"], -) +@pytest.mark.skipif(not _kda_fla_available, reason="KDA FLA kernels not available") +@pytest.mark.parametrize("use_backup", [False, True], ids=["fast", "backup"]) def test_kda(testing_device, use_backup, monkeypatch): if use_backup: import fast_llm.layers.ssm.kda as kda_module diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 094cbc094..0b4dbafc1 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -333,6 +333,8 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic device=testing_device, ) kwargs = {"output_hidden_states": True} + if is_mtp := (model_ref.fast_llm_base_model.config.head.prediction_heads > 1): + kwargs["return_all_prediction_heads"] = True if model_testing_config.model_type == "multimodal": kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(testing_device) kwargs["image_sizes"] = torch.tensor( @@ -388,16 +390,23 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic if model_testing_config.model_type == "multimodal" and hasattr(model, "vision_encoder"): kwargs["output_vision_hidden_states"] = True output = model(test_input, **kwargs) - hidden_states = output.hidden_states + (output.logits,) + # Fast-LLM doesn't concatenate the head hidden states. + hidden_states = ( + output.hidden_states[:-1] + output.hidden_states[-1].unbind(-2) if is_mtp else output.hidden_states + ) + (output.logits,) # Llava models doesn't return vision hidden states, so we run the vision model directly instead.
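Since the ids list pairs with the parameter values positionally, here is a minimal reminder of the convention the two SSM tests above rely on (a toy test, not part of the suite):

import pytest


@pytest.mark.parametrize("use_backup", [False, True], ids=["fast", "backup"])
def test_labelled_variants(use_backup):
    # Collected as test_labelled_variants[fast] and test_labelled_variants[backup].
    assert isinstance(use_backup, bool)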
if model_testing_config.model_type == "multimodal": - if hasattr(model, "vision_tower"): - vision_output = model.vision_tower( + # transformers v5: LlavaForConditionalGeneration wraps submodules under model.* + vision_model = ( + model.model if hasattr(model, "model") and hasattr(model.model, "vision_tower") else model + ) + if hasattr(vision_model, "vision_tower"): + vision_output = vision_model.vision_tower( pixel_values=kwargs["pixel_values"], image_sizes=kwargs["image_sizes"], output_hidden_states=True, ) - adapter_output = model.multi_modal_projector(vision_output.hidden_states[-1]) + adapter_output = vision_model.multi_modal_projector(vision_output.hidden_states[-1]) hidden_states = vision_output.hidden_states + (adapter_output,) + hidden_states hidden_states_ref_ = hidden_states_ref.copy() # Adjust the vision hidden states diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 0f89d9323..7de1de2da 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -475,7 +475,7 @@ def update_and_add_testing_config( "llama", "mtp_llama", updates={ - ("model", "base_model", "decoder", "num_blocks"): 1, + ("model", "base_model", "decoder", "num_blocks"): 2, ("model", "base_model", "head", "prediction_heads"): 2, }, # Megatron doesn't support multi-token prediction.
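Finally, the MTP comparison in test_huggingface_model relies on stack and unbind being inverses along the new head axis; a tiny round-trip with toy shapes:

import torch

batch, seq, heads, vocab = 2, 3, 2, 5
per_head = [torch.randn(batch, seq, vocab) for _ in range(heads)]
stacked = torch.stack(per_head, dim=-2)   # [batch, seq, heads, vocab], as the HF wrapper returns
recovered = stacked.unbind(-2)            # back to one [batch, seq, vocab] tensor per head
assert all(torch.equal(a, b) for a, b in zip(per_head, recovered))
print(stacked.shape)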