diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py
index a1b600445..4f8595250 100644
--- a/fast_llm/layers/block/config.py
+++ b/fast_llm/layers/block/config.py
@@ -84,6 +84,7 @@ def get_layer(
         *,
         lr_scale: float | None,
         peft: PeftConfig | None,
+        **kwargs,
     ) -> "BlockBase":
         return self.layer_class(
             self,
@@ -91,6 +92,7 @@ def get_layer(
             hidden_dim=hidden_dim,
             lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
             peft=peft,
+            **kwargs,
         )
 
     def get_reference_models(self) -> set[str]:
@@ -106,6 +108,10 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
             return FixedBlockSequenceConfig._from_dict(default, strict)
         return super()._from_dict(default, strict=strict)
 
+    @property
+    def last_block_config(self) -> BlockConfig:
+        raise NotImplementedError()
+
 
 @config_class(dynamic_type={BlockSequenceConfig: "fixed"})
 class FixedBlockSequenceConfig(BlockSequenceConfig):
@@ -130,6 +136,10 @@ def layer_class(self) -> "type[FixedBlockSequence]":
     def get_reference_models(self) -> set[str]:
         return self.block.get_reference_models()
 
+    @property
+    def last_block_config(self) -> BlockConfig:
+        return self.block
+
 
 @config_class(dynamic_type={BlockSequenceConfig: "pattern"})
 class PatternBlockSequenceConfig(BlockSequenceConfig):
@@ -161,6 +171,10 @@ def _validate(self):
 
         super()._validate()
 
+    @property
+    def last_block_config(self) -> BlockConfig:
+        return self.blocks[self.expanded_pattern[-1]]
+
     @property
     def layer_class(self) -> "type[PatternBlockSequence]":
         from fast_llm.layers.block.sequence import PatternBlockSequence
diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py
index 54a5b3471..2e7425343 100644
--- a/fast_llm/layers/block/sequence.py
+++ b/fast_llm/layers/block/sequence.py
@@ -24,6 +24,7 @@ def __init__(
         hidden_dim: TensorDim,
         lr_scale: float | None,
         peft: PeftConfig | None,
+        return_last_layer_input: bool = False,
     ):
         super().__init__(
             config,
@@ -40,8 +41,13 @@ def __init__(
                     hidden_dim,
                     lr_scale=self._lr_scale,
                     peft=self._peft,
+                    **(
+                        {"return_input": True}
+                        if return_last_layer_input and block_index == self._config.num_blocks - 1
+                        else {}
+                    ),
                 )
-                for _ in range(self._config.num_blocks)
+                for block_index in range(self._config.num_blocks)
             ]
         )
@@ -75,6 +81,7 @@ def __init__(
         hidden_dim: TensorDim,
         lr_scale: float | None,
         peft: PeftConfig | None,
+        return_last_layer_input: bool = False,
     ):
         super().__init__(
             config,
@@ -90,8 +97,13 @@ def __init__(
                     hidden_dim,
                     lr_scale=self._lr_scale,
                     peft=self._peft,
+                    **(
+                        {"return_input": True}
+                        if return_last_layer_input and block_index == self._config.num_blocks - 1
+                        else {}
+                    ),
                 )
-                for name in self._config.expanded_pattern
+                for block_index, name in enumerate(self._config.expanded_pattern)
             ]
         )
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index e3446bba6..0e54e7583 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -1,4 +1,3 @@
-import abc
 import typing
 
 from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
@@ -14,7 +13,7 @@
 if typing.TYPE_CHECKING:
     from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
-    from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase
+    from fast_llm.layers.language_model.head import LanguageModelHead
     from fast_llm.layers.language_model.language_model import LanguageModel
     from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction
 
@@ -95,41 +94,8 @@ def layer_class(self) -> "type[LanguageModelEmbedding]":
         return LanguageModelEmbedding
 
 
-@config_class(registry=True)
-class LanguageModelHeadBaseConfig(BlockConfig):
-    @classmethod
-    def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
-        if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None:
-            # Default subclass.
-            return LanguageModelHeadConfig._from_dict(default, strict)
-        return super()._from_dict(default, strict=strict)
-
-    def get_layer(
-        self,
-        distributed_config: DistributedConfig,
-        embeddings_config: LanguageModelEmbeddingsConfig,
-        *,
-        hidden_dim: TensorDim,
-        lr_scale: float | None,
-        peft: PeftConfig | None,
-    ) -> "LanguageModelHeadBase":
-        return self.layer_class(
-            self,
-            distributed_config,
-            embeddings_config,
-            hidden_dim=hidden_dim,
-            lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
-            peft=peft,
-        )
-
-    @property
-    @abc.abstractmethod
-    def max_prediction_distance(self) -> int:
-        pass
-
-
-@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"})
-class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
+@config_class()
+class LanguageModelHeadConfig(BlockConfig):
     _abstract = False
     normalization: NormalizationConfig = Field(
         desc="Configuration for the final normalization layer.",
@@ -160,6 +126,18 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.geq, 0),
     )
+    prediction_heads: int = Field(
+        default=1,
+        desc="Prediction heads.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    prediction_loss_coefficient: list[float] | None = Field(
+        default=None,
+        desc="Loss coefficient for each prediction head.",
+        doc="If not provided, all heads are equally weighted.",
+        hint=FieldHint.feature,
+    )
 
     def get_layer(
         self,
@@ -169,85 +147,36 @@
         hidden_dim: TensorDim,
         lr_scale: float | None,
         peft: PeftConfig | None,
-        prediction_distance: int = 0,
-        prediction_heads: int = 1,
-        loss_coefficient: float = 1.0,
-    ):
-        return self.layer_class(
+        block_config: DecoderBlockConfig | None = None,
+    ) -> "tuple[LanguageModelHead, MultiTokenPrediction]":
+        from fast_llm.layers.language_model.head import LanguageModelHead
+        from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction
+
+        return LanguageModelHead(
+            self,
+            distributed_config,
+            embeddings_config,
+            hidden_dim=hidden_dim,
+            lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
+            peft=peft,
+        ), MultiTokenPrediction(
             self,
             distributed_config,
             embeddings_config,
             hidden_dim=hidden_dim,
             lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
             peft=peft,
-            prediction_distance=prediction_distance,
-            prediction_heads=prediction_heads,
-            loss_coefficient=loss_coefficient,
+            block_config=block_config,
         )
 
-    @property
-    def layer_class(self) -> "type[LanguageModelHead]":
-        from fast_llm.layers.language_model.head import LanguageModelHead
-
-        return LanguageModelHead
-
     def _validate(self) -> None:
         super()._validate()
         assert LM_HEAD_LOSS_NAME not in self.losses
 
-    @property
-    def max_prediction_distance(self) -> int:
-        return 1
-
     def get_reference_models(self) -> set[str]:
         return {reference_model for loss in self.losses.values() for reference_model in loss.get_reference_models()}
 
 
-@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"})
-class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig):
-    _abstract = False
-    # Needs to be `DecoderBlockConfig` for the `return_input` interface.
-    # TODO: Make a generic wrapper for returning input instead?
-    block: DecoderBlockConfig = Field(
-        desc="Configuration for the decoder block before each head.",
-        hint=FieldHint.architecture,
-    )
-    # TODO: Generalize? (needs the extra initialization arguments)
-    head: LanguageModelHeadConfig = Field(
-        desc="Configuration for the multi-token-prediction heads.",
-        hint=FieldHint.architecture,
-    )
-    prediction_heads: int = Field(
-        default=1,
-        desc="Prediction heads.",
-        hint=FieldHint.architecture,
-        valid=check_field(Assert.gt, 0),
-    )
-    prediction_loss_coefficient: list[float] | None = Field(
-        default=None,
-        desc="Loss coefficient for each prediction head.",
-        doc="If not provided, all heads are equally weighted.",
-        hint=FieldHint.feature,
-    )
-
-    def _validate(self) -> None:
-        super()._validate()
-        if isinstance(self.prediction_loss_coefficient, list):
-            Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
-            for coeff in self.prediction_loss_coefficient:
-                Assert.geq(coeff, 0)
-
-    @property
-    def layer_class(self) -> "type[MultiTokenPrediction]":
-        from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction
-
-        return MultiTokenPrediction
-
-    @property
-    def max_prediction_distance(self) -> int:
-        return self.prediction_heads
-
-
 @config_class()
 class LanguageModelConfig(BlockConfig):
     decoder: BlockSequenceConfig = Field(
@@ -258,7 +187,7 @@ class LanguageModelConfig(BlockConfig):
         hint=FieldHint.architecture,
         desc="Configuration for the language model embeddings.",
     )
-    head: LanguageModelHeadBaseConfig = Field(
+    head: LanguageModelHeadConfig = Field(
         hint=FieldHint.architecture, desc="Configuration for the language model head(s)."
     )
     tied_embedding_weight: bool = Field(
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index c5bf9ff9b..85b9bde1d 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -1,4 +1,3 @@
-import abc
 import functools
 import logging
 import typing
@@ -14,12 +13,11 @@
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
 from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
-from fast_llm.layers.block.block import Block, BlockBase
+from fast_llm.layers.block.block import Block
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.language_model.config import (
     LM_HEAD_LOSS_NAME,
     LanguageModelEmbeddingsConfig,
-    LanguageModelHeadBaseConfig,
     LanguageModelHeadConfig,
     LanguageModelKwargs,
 )
@@ -32,15 +30,7 @@
 OUTPUT_WEIGHTS = "output_weights"
 
 
-class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](BlockBase[ConfigType]):
-    heads: "list[LanguageModelHead]"
-
-    @abc.abstractmethod
-    def get_output_weights(self) -> list[torch.Tensor]:
-        pass
-
-
-class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType], Block):
+class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]):
     """
     A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable).
     TODO: Cleanup (dynamic type? composition?)
@@ -58,7 +48,6 @@ def __init__(
         lr_scale: float | None,
         peft: PeftConfig | None,
         prediction_distance: int = 0,
-        prediction_heads: int = 1,
         loss_coefficient: float = 1.0,
     ):
         super().__init__(
@@ -68,11 +57,9 @@ def __init__(
             lr_scale=lr_scale,
            peft=peft,
         )
-        Assert.in_range(prediction_distance, 0, prediction_heads)
+        Assert.in_range(prediction_distance, 0, self._config.prediction_heads)
         self._prediction_distance = prediction_distance
-        self._prediction_heads = prediction_heads
-        self._loss_coefficient = loss_coefficient
-        self._is_last_head = self._prediction_distance == self._prediction_heads - 1
+        self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1
 
         self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel
         self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
@@ -99,17 +86,22 @@ def __init__(
         loss_configs = (
             self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()}
         )
+        loss_coefficient = (
+            1.0
+            if self._config.prediction_loss_coefficient is None
+            else self._config.prediction_loss_coefficient[self._prediction_distance]
+        )
         self.losses = torch.nn.ModuleList(
             [
                 loss_config.get_layer(
                     distributed_config,
                     self._get_full_loss_name(name),
                     self._prediction_distance,
-                    self._prediction_heads,
+                    self._config.prediction_heads,
                     self._vocab_parallel,
                     self._config.cross_entropy_splits,
                     self._config.logits_scale_factor,
-                    self._loss_coefficient,
+                    loss_coefficient,
                 )
                 for name, loss_config in loss_configs.items()
             ]
@@ -305,8 +297,3 @@ def _get_full_loss_name(self, name) -> str:
     @functools.cached_property
     def _total_loss_name(self) -> str:
         return self._get_full_loss_name(LM_HEAD_LOSS_NAME)
-
-    @property
-    def heads(self) -> "list[LanguageModelHead]":
-        # For compatibility with MTP.
-        return [self]
diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py
index 385bab7ef..32e2ccbf9 100644
--- a/fast_llm/layers/language_model/language_model.py
+++ b/fast_llm/layers/language_model/language_model.py
@@ -44,28 +44,42 @@ def __init__(
             self._hidden_dim,
             lr_scale=self._lr_scale,
             peft=self._peft,
+            **({"return_last_layer_input": True} if self._config.head.prediction_heads > 1 else {}),
         )
-        self.head = self._config.head.get_layer(
+        self.head, self.multi_token_prediction = self._config.head.get_layer(
             distributed_config,
             self._config.embeddings,
             hidden_dim=self._hidden_dim,
             lr_scale=self._lr_scale,
             peft=self._peft,
+            **(
+                {"block_config": self._config.decoder.last_block_config}
+                if self._config.head.prediction_heads > 1
+                else {}
+            ),
         )
 
     def get_layers(self) -> list[Layer]:
-        return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()
+        layers = self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()
+        if self.multi_token_prediction is not None:
+            layers += self.multi_token_prediction.get_layers()
+        return layers
 
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable?
         self.embeddings.preprocess(kwargs)
         self.decoder.preprocess(kwargs)
         self.head.preprocess(kwargs)
+        if self.multi_token_prediction is not None:
+            self.multi_token_prediction.preprocess(kwargs)
 
     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
         # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable?
-        return (
+        losses = (
             self.embeddings.get_loss_definitions(count)
             + self.decoder.get_loss_definitions(count)
             + self.head.get_loss_definitions(count)
         )
+        if self.multi_token_prediction is not None:
+            losses += self.multi_token_prediction.get_loss_definitions(count)
+        return losses
diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py
index 537c7996d..e326b9555 100644
--- a/fast_llm/layers/language_model/loss/entropy_loss.py
+++ b/fast_llm/layers/language_model/loss/entropy_loss.py
@@ -13,9 +13,6 @@
 
 
 class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
     def forward_backward(
         self,
         logits: "torch.Tensor",
@@ -41,11 +38,6 @@ def forward_backward(
 
 
 class LanguageModelDistillationLoss[ConfigType: LanguageModelDistillationLossConfig](LanguageModelLoss[ConfigType]):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        if self._prediction_distance > 0:
-            raise NotImplementedError()
-
     def forward_backward(
         self,
         logits: "torch.Tensor",
diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py
index 8dc88c4a1..f1f65ac39 100644
--- a/fast_llm/layers/language_model/loss/loss.py
+++ b/fast_llm/layers/language_model/loss/loss.py
@@ -102,7 +102,6 @@ def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0):
         return None if loss_mask is None else self._prepare_target(loss_mask, kwargs, split_index)
 
     def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0):
-        assert self._prediction_distance == 0
         Assert.incl(
             logits_name := self.module_name.rsplit(".", 2)[0] + f".logits",
             reference_hidden_states := kwargs[f"reference_{reference_model}_hidden_states"],
diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py
index c606e2d68..720592c41 100644
--- a/fast_llm/layers/language_model/loss/z_loss.py
+++ b/fast_llm/layers/language_model/loss/z_loss.py
@@ -10,12 +10,6 @@
 
 
 class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # TODO: Support vocab_parallel
-        if self._vocab_parallel:
-            raise NotImplementedError()
-
     def forward_backward(
         self,
         logits: "torch.Tensor",
diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py
index 5efe2d836..d7665cf00 100644
--- a/fast_llm/layers/language_model/multi_token_prediction.py
+++ b/fast_llm/layers/language_model/multi_token_prediction.py
@@ -7,12 +7,14 @@
 from fast_llm.engine.base_model.config import LossDef
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.layers.block.block import BlockBase
 from fast_llm.layers.common.peft.config import PeftConfig
-from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig
-from fast_llm.layers.language_model.head import LanguageModelHeadBase
+from fast_llm.layers.decoder.config import DecoderBlockConfig
+from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelHeadConfig
+from fast_llm.layers.language_model.head import LanguageModelHead
 
 
-class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](LanguageModelHeadBase[ConfigType]):
+class MultiTokenPrediction[ConfigType: LanguageModelHeadConfig](BlockBase[ConfigType]):
     _config: ConfigType
 
     def __init__(
@@ -24,6 +26,7 @@ def __init__(
         hidden_dim: TensorDim,
         lr_scale: float | None,
         peft: PeftConfig | None,
+        block_config: DecoderBlockConfig | None = None,
     ):
         super().__init__(
             config,
@@ -32,9 +35,12 @@ def __init__(
             lr_scale=lr_scale,
             peft=peft,
         )
+        self._enabled = self._config.prediction_heads > 1
+        if self._enabled:
+            assert block_config is not None
         self.blocks = torch.nn.ModuleList(
             [
-                self._config.block.get_layer(
+                block_config.get_layer(
                     self._distributed_config,
                     self._hidden_dim,
                     lr_scale=self._lr_scale,
                     peft=self._peft,
@@ -43,26 +49,21 @@ def __init__(
                     # The previous blocks return a stack of shared_hidden and transformer_output.
                     return_input=index < self._config.prediction_heads - 1,
                 )
-                for index in range(self._config.prediction_heads)
+                for index in range(1, self._config.prediction_heads)
             ]
         )
         self.heads = torch.nn.ModuleList(
             [
-                self._config.head.get_layer(
+                LanguageModelHead(
+                    self._config,
                     distributed_config,
                     embeddings_config,
                     hidden_dim=hidden_dim,
                     lr_scale=lr_scale,
                     peft=peft,
                     prediction_distance=index,
-                    prediction_heads=self._config.prediction_heads,
-                    loss_coefficient=(
-                        1.0
-                        if self._config.prediction_loss_coefficient is None
-                        else self._config.prediction_loss_coefficient[index]
-                    ),
                 )
-                for index in range(self._config.prediction_heads)
+                for index in range(1, self._config.prediction_heads)
             ]
         )
@@ -70,8 +71,11 @@ def __init__(
     def _layers_with_namespace(self) -> list[Layer]:
         # Wrap all blocks in a namespace using the unique module name of the first one.
         # This needs to be in a property because `module_name` is set after `__init__`.
-        namespace = self.blocks[0].module_name
-        return [LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers()]
+        return [
+            LayerWithNamespace(sublayer, self.blocks[0].module_name)
+            for layer in self.blocks
+            for sublayer in layer.get_layers()
+        ]
 
     def get_layers(self) -> list[Layer]:
         return [
@@ -84,9 +88,13 @@ def get_output_weights(self) -> list[torch.Tensor]:
         return sum((head.get_output_weights() for head in self.heads), [])
 
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
-        self._layers_with_namespace[0].preprocess(kwargs)
+        if self._enabled:
+            self._layers_with_namespace[0].preprocess(kwargs)
 
     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
-        return self.blocks[0].get_loss_definitions(count=count * self._config.prediction_heads) + [
-            loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count)
-        ]
+        return (
+            self.blocks[0].get_loss_definitions(count=count * (self._config.prediction_heads - 1))
+            + [loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count)]
+            if self._enabled
+            else []
+        )
diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index 314741c3b..ddcbcf696 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -168,8 +168,8 @@ def _validate(self) -> None:
 
         for reference_model in self.reference_models.values():
             Assert.geq(
-                reference_model.model.base_model.head.max_prediction_distance,
-                self.model.base_model.head.max_prediction_distance,
+                reference_model.model.base_model.head.prediction_heads,
+                self.model.base_model.head.prediction_heads,
             )
             Assert.empty(reference_model.model.base_model.get_reference_models())
             Assert.eq(
diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py
index 00d871dbf..983df9869 100644
--- a/fast_llm/models/gpt/conversion/llama.py
+++ b/fast_llm/models/gpt/conversion/llama.py
@@ -488,16 +488,15 @@ def get_converters(
         cls,
         config: LanguageModelHeadConfig,
         exported_config: dict,
-        fast_llm_prefix: str,
     ) -> list[WeightConverter]:
         return [
             *cls.normalization_converter_class.get_converters(
                 config.normalization,
-                f"{fast_llm_prefix}.final_norm",
+                f"head.final_norm",
                 f"model.norm",
             ),
             get_parameter_converter(
-                f"{fast_llm_prefix}.output_weights",
+                f"head.output_weights",
                 "lm_head.weight",
                 drop_on_import=exported_config["tie_word_embeddings"],
                 drop_on_export=exported_config["tie_word_embeddings"],
@@ -539,7 +538,7 @@ def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> li
         return [
             *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"),
             *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"),
-            *cls.head_converter_class.get_converters(config.head, exported_config, "head"),
+            *cls.head_converter_class.get_converters(config.head, exported_config),
         ]
diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py
index 5b83fed69..0c58b7be5 100644
--- a/fast_llm/models/gpt/conversion/mtp_llama.py
+++ b/fast_llm/models/gpt/conversion/mtp_llama.py
@@ -5,16 +5,14 @@
 from fast_llm.engine.checkpoint.config import CheckpointFormat
 from fast_llm.engine.checkpoint.external import WeightConverter
 from fast_llm.layers.block.config import FixedBlockSequenceConfig
-from fast_llm.layers.language_model.config import LanguageModelHeadConfig, MultiTokenPredictionConfig
+from fast_llm.layers.language_model.config import LanguageModelHeadConfig
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat
 from fast_llm.models.gpt.conversion.llama import (
     LlamaBaseModelConverter,
-    LlamaBlockConverter,
     LlamaDecoderConverter,
     LlamaHeadConverter,
     LlamaHuggingfaceCheckpointHandler,
-    get_parameter_converter,
 )
 from fast_llm.utils import Assert, safe_merge_dicts
@@ -23,17 +21,14 @@ class MTPLlamaHeadConverter(LlamaHeadConverter):
     @classmethod
     def import_config(cls, config: dict) -> dict:
         return {
-            "type": "multi_token_prediction",
-            "block": LlamaBlockConverter.import_config(config),
-            "head": super().import_config(config),
+            **super().import_config(config),
             "prediction_heads": config["prediction_heads"],
         }
 
     @classmethod
-    def export_config(cls, config: MultiTokenPredictionConfig) -> dict:
-        Assert.custom(isinstance, config, MultiTokenPredictionConfig)
+    def export_config(cls, config: LanguageModelHeadConfig) -> dict:
         return safe_merge_dicts(
-            super().export_config(config.head),
+            super().export_config(config),
             {"prediction_heads": config.prediction_heads},
         )
@@ -42,33 +37,15 @@ def get_converters(
         cls,
         config: LanguageModelHeadConfig,
         exported_config: dict,
-        fast_llm_prefix: str,
     ) -> list[WeightConverter]:
-        converters = []
-        for prediction_distance in range(config.prediction_heads):
-            converters += cls.block_converter_class.get_converters(
-                config.block,
-                f"{fast_llm_prefix}.blocks.{prediction_distance}",
-                (
-                    f"model.layers.{exported_config["num_hidden_layers"]-1}"
-                    if prediction_distance == 0
-                    else f"model.mtp_heads.{prediction_distance - 1}"
-                ),
-            )
-            converters += cls.normalization_converter_class.get_converters(
+        return super().get_converters(config, exported_config) + [
+            cls.normalization_converter_class.get_converters(
                 config.head.normalization,
-                f"{fast_llm_prefix}.heads.{prediction_distance}.final_norm",
+                f"multi_token_prediction.heads.{prediction_distance - 1}.final_norm",
                 f"model.mtp_norms.{prediction_distance}",
             )
-            converters.append(
-                get_parameter_converter(
-                    f"{fast_llm_prefix}.heads.0.output_weights",
-                    "lm_head.weight",
-                    drop_on_import=exported_config["tie_word_embeddings"],
-                )
-            )
-
-        return converters
+            for prediction_distance in range(1, config.prediction_heads)
+        ]
 
 
 class MTPLlamaDecoderConverter(LlamaDecoderConverter):
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index cabcdc489..698f624ed 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -250,7 +250,7 @@ def preprocess_batch(
             kwargs[BlockKwargs.output_hidden_states].append(re.compile(r"head\..*logits.*$"))
         else:
             labels_begin = tokens_begin + 1
-            labels_end = tokens_end + self._config.head.max_prediction_distance
+            labels_end = tokens_end + self._config.head.prediction_heads
             labels = batch.tokens.crop(labels_begin, labels_end).tokens
 
             if batch.loss_masking_spans is not None:
diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py
index ded0f81c8..ef4956176 100644
--- a/fast_llm/models/gpt/trainer.py
+++ b/fast_llm/models/gpt/trainer.py
@@ -25,7 +25,7 @@ def _get_sampling_parameters(
             {
                 "sequence_length": self._config.batch.sequence_length,
                 "truncate_documents": self._config.batch.truncate_documents,
-                "extra_tokens": self._config.model.base_model.head.max_prediction_distance,
+                "extra_tokens": self._config.model.base_model.head.prediction_heads,
             }
         )
         return parameters if _return_dict else SamplingParameters(**parameters)
diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py
index 8703ef920..a75d732b8 100644
--- a/fast_llm/models/multimodal/conversion/llava.py
+++ b/fast_llm/models/multimodal/conversion/llava.py
@@ -258,16 +258,15 @@ def get_converters(
         cls,
         config: LanguageModelHeadConfig,
         exported_config: dict,
-        fast_llm_prefix: str,
     ) -> list[WeightConverter]:
         return [
             *cls.normalization_converter_class.get_converters(
                 config.normalization,
-                f"{fast_llm_prefix}.final_norm",
+                f"head.final_norm",
                 f"language_model.model.norm",
             ),
             get_parameter_converter(
-                f"{fast_llm_prefix}.output_weights",
+                f"head.output_weights",
                 "language_model.lm_head.weight",
                 drop_on_import=exported_config["tie_word_embeddings"],
             ),
@@ -320,7 +319,7 @@ def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict
                 config.decoder, "decoder", "language_model.model.layers"
             ),
             *cls.language_model_converter_class.head_converter_class.get_converters(
-                config.head, {"tie_word_embeddings": False}, "head"
+                config.head, {"tie_word_embeddings": False}
             ),
         ]
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index b1a922099..a8ae85c12 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -47,6 +47,7 @@ def get_config(self) -> GPTModelConfig:
             "normalization": {"type": "rms_norm"},
             "logits_scale_factor": self.logits_scale_factor,
             "cross_entropy_splits": self.num_splits,
+            "prediction_heads": self.prediction_heads,
         }
         losses = {}
         if self.label_loss is not False:
@@ -69,15 +70,7 @@ def get_config(self) -> GPTModelConfig:
                 "base_model": {
                     "decoder": {"num_blocks": 0},
                     "embeddings": {"vocab_size": VOCAB_SIZE, "full_precision_residual": self.full_precision_residual},
-                    "head": (
-                        head_config
-                        if self.prediction_heads == 1
-                        else {
-                            "type": "multi_token_prediction",
-                            "head": head_config,
-                            "prediction_heads": self.prediction_heads,
-                        }
-                    ),
+                    "head": head_config,
                     "hidden_size": HIDDEN_SIZE,
                     "tied_embedding_weight": self.tied_embedding_weight,
                 },
@@ -246,8 +239,9 @@ def test_lm_head(test_config: LMHeadTestConfig):
         else None
     )
 
-    for prediction_distance, head in enumerate(model.head.heads):
+    for prediction_distance in range(model_config.base_model.head.prediction_heads):
         # Prepare the LM head
+        head = model.head if prediction_distance == 0 else model.multi_token_prediction.heads[prediction_distance - 1]
         Assert.custom(isinstance, head, LanguageModelHead)
         Assert.eq(head._prediction_distance, prediction_distance)
         is_duplicate = test_config.tied_embedding_weight or prediction_distance > 0
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 7b41c1f50..40dbb7d29 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -470,13 +470,8 @@ def update_and_add_testing_config(
     "llama",
     "mtp_llama",
     updates={
-        ("model", "base_model", "head"): {
-            "type": "multi_token_prediction",
-            "block": _llama_block,
-            "head": MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["head"],
-            "prediction_heads": 2,
-        },
         ("model", "base_model", "decoder", "num_blocks"): 1,
+        ("model", "base_model", "head", "prediction_heads"): 1,
     },
     # Megatron doesn't support multi-token prediction.
     megatron_args=None,