Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions fast_llm/engine/inference/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import dataclasses
import logging
import os
import pathlib
Expand All @@ -12,20 +13,32 @@

logger = logging.getLogger(__name__)

_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)


class HuggingfaceModelConfig(transformers.PretrainedConfig):
model_type = "fast_llm"
model_config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig
fast_llm_config: FastLLMModelConfig | None = None
use_cache: bool = True

def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs):
def __post_init__(self, **kwargs):
# Fall back to a default-constructed model config when none is given; needed for `to_diff_dict` (`__repr__`).
if fast_llm_config is None:
fast_llm_config = self.model_config_class()
self.fast_llm_config = fast_llm_config
self.use_cache = kwargs.pop("use_cache", True)
super().__init__(**kwargs)
if self.torch_dtype is not None:
assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch
if self.fast_llm_config is None:
self.fast_llm_config = self.model_config_class()
super().__post_init__(**kwargs)
if self.dtype is not None:
assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch

if _TRANSFORMERS_V4:

def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs):
# Fall back to a default-constructed model config when none is given; needed for `to_diff_dict` (`__repr__`).
self.fast_llm_config = fast_llm_config if fast_llm_config is not None else self.model_config_class()
self.use_cache = kwargs.pop("use_cache", True)
super().__init__(**kwargs)
if self.torch_dtype is not None:
assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch

def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None:
# Hack the method to save at the right place.
Expand Down Expand Up @@ -88,9 +101,9 @@ def _get_config_dict(
)
metadata = cls.model_config_class.load_metadata(pretrained)
updates = {}
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
updates[("distributed", "compute_dtype")] = torch_dtype
dtype = kwargs.pop("dtype", None) or kwargs.pop("torch_dtype", None) # torch_dtype: transformers v4
if dtype is not None:
updates[("distributed", "compute_dtype")] = dtype
fast_llm_config = cls.model_config_class.from_metadata(
pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates
)
Expand Down
12 changes: 9 additions & 3 deletions fast_llm/engine/inference/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.inference.config import HuggingfaceModelConfig
from fast_llm.engine.inference.config import _TRANSFORMERS_V4, HuggingfaceModelConfig
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.utils import Assert

if _TRANSFORMERS_V4:
from transformers.modeling_utils import no_init_weights as transformers_no_init_weights
else:
from transformers.initialization import no_init_weights as transformers_no_init_weights


logger = logging.getLogger(__name__)


Expand All @@ -38,7 +44,7 @@ def __init__(
**kwargs,
):
if config is None:
config = self.config_class(fast_llm_model.config)
config = self.config_class(fast_llm_config=fast_llm_model.config)

assert self.runner_class.model_class.config_class is config.model_config_class
assert config.fast_llm_config is fast_llm_model.config
Expand Down Expand Up @@ -70,7 +76,7 @@ def __init__(
# Transformers needs to be able to inspect the base model.
self.fast_llm_base_model = fast_llm_model.base_model

with transformers.modeling_utils.no_init_weights():
with transformers_no_init_weights():
self.post_init()

if fast_llm_model.config.multi_stage.zero_stage == 3:
Expand Down
98 changes: 58 additions & 40 deletions fast_llm/models/gpt/conversion/llama.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import logging
import typing

import torch
import transformers

from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.checkpoint.external import (
Expand Down Expand Up @@ -30,6 +32,8 @@
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert, div, safe_merge_dicts

_TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -188,32 +192,37 @@ def import_weight(
class LlamaAttentionConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
try:
rope_type = config["rope_scaling"]["rope_type"]
except (KeyError, TypeError):
rope_type = "default"
rotary_config = {
"type": rope_type,
"theta": config["rope_theta"],
}
# Normalize the rope parameters into a single dict before dispatching on `rope_type`.
# transformers v5 checkpoints consolidate `rope_theta` and `rope_scaling` into `rope_parameters`.
# transformers v4 checkpoints keep `rope_theta` at the top level and use a `rope_scaling` dict for non-default types.
# Note: the format is detected from the checkpoint contents, not from the installed transformers version,
# so old checkpoints remain loadable with transformers v5.
if "rope_parameters" in config: # transformers v5
rope_params = config["rope_parameters"]
rope_theta = rope_params["rope_theta"]
else: # transformers v4
rope_params = config.get("rope_scaling") or {}
rope_theta = config["rope_theta"]
rope_type = rope_params.get("rope_type", "default")
rotary_config = {"type": rope_type, "theta": rope_theta}
if rope_type == "default":
pass
elif rope_type == "llama3":
rotary_config.update(
{
"scale_factor": config["rope_scaling"]["factor"],
"low_frequency_factor": config["rope_scaling"]["low_freq_factor"],
"high_frequency_factor": config["rope_scaling"]["high_freq_factor"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
"scale_factor": rope_params["factor"],
"low_frequency_factor": rope_params["low_freq_factor"],
"high_frequency_factor": rope_params["high_freq_factor"],
"original_context_length": rope_params["original_max_position_embeddings"],
}
)
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": config["rope_scaling"]["attention_factor"],
"beta_fast": config["rope_scaling"]["beta_fast"],
"beta_slow": config["rope_scaling"]["beta_slow"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
"attention_factor": rope_params["attention_factor"],
"beta_fast": rope_params["beta_fast"],
"beta_slow": rope_params["beta_slow"],
"original_context_length": rope_params["original_max_position_embeddings"],
}
)
else:
Expand All @@ -235,36 +244,45 @@ def import_config(cls, config: dict) -> dict:
def export_config(cls, config: AttentionConfig) -> dict:
cls._check_config(config)
Assert.eq(config.softmax_scale_power, 0.5)
out = {
"num_attention_heads": config.heads,
"num_key_value_heads": config.head_groups,
"head_dim": config.head_size,
"attention_bias": config.add_linear_biases,
"attention_dropout": config.dropout,
"rope_theta": config.rotary.theta,
}
rope_parameters = {"rope_theta": config.rotary.theta}
if type(config.rotary) is DefaultRotaryConfig:
pass
rope_parameters["rope_type"] = "default"
elif type(config.rotary) is Llama3RotaryConfig:
out["rope_scaling"] = {
"rope_type": "llama3",
"factor": config.rotary.scale_factor,
"low_freq_factor": config.rotary.low_frequency_factor,
"high_freq_factor": config.rotary.high_frequency_factor,
"original_max_position_embeddings": config.rotary.original_context_length,
}
rope_parameters.update(
{
"rope_type": "llama3",
"factor": config.rotary.scale_factor,
"low_freq_factor": config.rotary.low_frequency_factor,
"high_freq_factor": config.rotary.high_frequency_factor,
"original_max_position_embeddings": config.rotary.original_context_length,
}
)
elif type(config.rotary) is YarnRotaryConfig:
out["rope_scaling"] = {
"rope_type": "yarn",
"attention_factor": config.rotary.attention_factor,
"beta_fast": config.rotary.beta_fast,
"beta_slow": config.rotary.beta_slow,
"original_max_position_embeddings": config.rotary.original_context_length,
}
rope_parameters.update(
{
"rope_type": "yarn",
"attention_factor": config.rotary.attention_factor,
"beta_fast": config.rotary.beta_fast,
"beta_slow": config.rotary.beta_slow,
"original_max_position_embeddings": config.rotary.original_context_length,
}
)
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")

return out
common = {
"num_attention_heads": config.heads,
"num_key_value_heads": config.head_groups,
"head_dim": config.head_size,
"attention_bias": config.add_linear_biases,
"attention_dropout": config.dropout,
}
if _TRANSFORMERS_V4:
out = {**common, "rope_theta": rope_parameters["rope_theta"]}
if type(config.rotary) is not DefaultRotaryConfig:
out["rope_scaling"] = {k: v for k, v in rope_parameters.items() if k != "rope_theta"}
return out
return {**common, "rope_parameters": rope_parameters}

@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
Expand Down
6 changes: 3 additions & 3 deletions fast_llm/models/gpt/conversion/mtp_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_converters(
converters += cls.block_converter_class.get_converters(
config.decoder.last_block_config,
f"multi_token_prediction.blocks.{prediction_distance-2}",
f"model.mtp_heads.{prediction_distance - 1}",
f"model.mtp_heads.{prediction_distance - 2}",
)
converters += cls.normalization_converter_class.get_converters(
config.head.normalization,
Expand All @@ -73,7 +73,7 @@ class MTPLlamaDecoderConverter(LlamaDecoderConverter):
def import_config(cls, config: dict) -> dict:
return {
"block": cls.block_converter_class.import_config(config),
"num_blocks": config["num_hidden_layers"] - 1,
"num_blocks": config["num_hidden_layers"],
}

@classmethod
Expand All @@ -82,7 +82,7 @@ def export_config(cls, config: FixedBlockSequenceConfig) -> dict:
Assert.custom(isinstance, config, FixedBlockSequenceConfig)
return safe_merge_dicts(
cls.block_converter_class.export_config(config.block),
{"num_hidden_layers": config.num_blocks + 1},
{"num_hidden_layers": config.num_blocks},
)


Expand Down
31 changes: 28 additions & 3 deletions fast_llm/models/gpt/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
class HuggingfaceGPTModelConfig(HuggingfaceModelConfig):
model_type = "fast_llm_gpt"
model_config_class = GPTModelConfig
fast_llm_config: GPTModelConfig

# transformers v5: PretrainedConfig is a dataclass, so redefining a field in a subclass
# would create a new dataclass field with a different default. Guard with TYPE_CHECKING
# so type checkers see the narrowed type without affecting the runtime dataclass layout.
if typing.TYPE_CHECKING:
fast_llm_config: GPTModelConfig


class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel):
Expand All @@ -41,6 +46,7 @@ def inner_forward(
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
return_all_prediction_heads: bool = False,
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
return self._inner_forward(
self._get_batch(input_ids, attention_mask),
Expand All @@ -52,6 +58,7 @@ def inner_forward(
output_attentions,
output_hidden_states,
return_dict,
return_all_prediction_heads,
)

def _get_batch(
Expand Down Expand Up @@ -89,6 +96,7 @@ def _inner_forward(
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
return_all_prediction_heads: bool = False,
) -> transformers.modeling_outputs.CausalLMOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -122,7 +130,18 @@ def _inner_forward(
}

# TODO: Handle MTP. Extra prediction-head logits are only returned when `return_all_prediction_heads` is set.
logits = hidden_states.pop("head.logits")

self.fast_llm_base_model.head.module_name
logits = hidden_states.pop(f"{self.fast_llm_base_model.head.module_name}.logits")
if return_all_prediction_heads:
logits = torch.stack(
[logits]
+ [
hidden_states.pop(f"{head.module_name}.logits")
for head in self.fast_llm_base_model.multi_token_prediction.heads
],
dim=-2,
)

output = transformers.modeling_outputs.CausalLMOutputWithPast(
logits=logits,
Expand Down Expand Up @@ -151,7 +170,13 @@ def _get_input(
output_hidden_states = (
[self.fast_llm_base_model.embeddings.module_name + "$"]
+ [layer.module_name + "$" for layer in self.fast_llm_base_model.decoder][:-1]
+ [self.fast_llm_base_model.head.final_norm.module_name + "$"]
+ [
head.final_norm.module_name + "$"
for head in [
self.fast_llm_base_model.head,
*self.fast_llm_base_model.multi_token_prediction.heads,
]
]
)

# This needs to be set before preprocessing so it propagates to layers with namespace.
Expand Down
6 changes: 3 additions & 3 deletions fast_llm/models/gpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ def preprocess_batch(
Assert.empty(kwargs.keys() & extra_kwargs.keys())
kwargs.update(extra_kwargs)
if phase == PhaseType.inference:
kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$"))
kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"(?:.*\.)?logits.*$"))

if not model_input.is_meta:
for name, reference_model in self._reference_models.items():
output_hidden_states = set()
if name in self._head_reference_models:
output_hidden_states.add(re.compile(r"head\..*logits.*$"))
output_hidden_states.add(re.compile(r"(?:.*\.)?logits.*$"))
if name in self._decoder_reference_models:
# TODO: Get the actual names
output_hidden_states.add(re.compile(r"decoder\.\d+\.mixer_output$"))
output_hidden_states.add(re.compile(r"(?:.*\.)?decoder\.\d+\.mixer_output$"))
assert len(output_hidden_states) >= 1
reference_model_input = dataclasses.replace(
model_input,
Expand Down
7 changes: 6 additions & 1 deletion fast_llm/models/multimodal/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig):
model_type = "fast_llm_multi_modal"
model_config_class = MultiModalModelConfig
fast_llm_config: MultiModalModelConfig

# transformers v5: PretrainedConfig is a dataclass, so redefining a field in a subclass
# would create a new dataclass field with a different default. Guard with TYPE_CHECKING
# so type checkers see the narrowed type without affecting the runtime dataclass layout.
if typing.TYPE_CHECKING:
fast_llm_config: MultiModalModelConfig


class HuggingfaceMultiModalModelForCausalLM(HuggingfaceGPTModelForCausalLM):
Expand Down
6 changes: 4 additions & 2 deletions fast_llm_external_models/apriel2/conversion/llava/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def convert_config(llava_config: dict) -> dict:
hidden_size = text_config["hidden_size"]
num_heads = text_config["num_attention_heads"]
num_kv_heads = text_config["num_key_value_heads"]
rope_theta = text_config["rope_theta"]
# transformers v4 checkpoints store `rope_theta` at the top level; v5 checkpoints store it inside `rope_parameters`.
# Falls back to 10000.0 when neither provides a value.
rope_theta = text_config.get("rope_theta") or text_config.get("rope_parameters", {}).get("rope_theta", 10000.0)
# Use explicit head_dim if available (some models have head_dim != hidden_size // num_heads)
# Note: MistralConfig.head_dim is None by default, so we must check for None explicitly
head_dim = text_config.get("head_dim")
Expand Down Expand Up @@ -98,7 +99,8 @@ def _convert_vision_config(llava_config: dict) -> dict:
num_heads = vision_config["num_attention_heads"]
num_layers = vision_config["num_hidden_layers"]
intermediate_size = vision_config["intermediate_size"]
rope_theta = vision_config["rope_theta"]
# transformers v4 checkpoints store `rope_theta` at the top level; v5 checkpoints store it inside `rope_parameters`.
# Falls back to 10000.0 when neither provides a value.
rope_theta = vision_config.get("rope_theta") or vision_config.get("rope_parameters", {}).get("rope_theta", 10000.0)
patch_size = vision_config["patch_size"]
num_channels = vision_config["num_channels"]
# Use explicit head_dim if available
Expand Down
Loading
Loading