14 changes: 14 additions & 0 deletions fast_llm/layers/block/config.py
@@ -84,13 +84,15 @@ def get_layer(
*,
lr_scale: float | None,
peft: PeftConfig | None,
**kwargs,
) -> "BlockBase":
return self.layer_class(
self,
distributed_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
**kwargs,
)

def get_reference_models(self) -> set[str]:
@@ -106,6 +108,10 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
return FixedBlockSequenceConfig._from_dict(default, strict)
return super()._from_dict(default, strict=strict)

@property
def last_block_config(self) -> BlockConfig:
raise NotImplementedError()


@config_class(dynamic_type={BlockSequenceConfig: "fixed"})
class FixedBlockSequenceConfig(BlockSequenceConfig):
@@ -130,6 +136,10 @@ def layer_class(self) -> "type[FixedBlockSequence]":
def get_reference_models(self) -> set[str]:
return self.block.get_reference_models()

@property
def last_block_config(self) -> BlockConfig:
return self.block


@config_class(dynamic_type={BlockSequenceConfig: "pattern"})
class PatternBlockSequenceConfig(BlockSequenceConfig):
@@ -161,6 +171,10 @@ def _validate(self):

super()._validate()

@property
def last_block_config(self) -> BlockConfig:
return self.blocks[self.expanded_pattern[-1]]

@property
def layer_class(self) -> "type[PatternBlockSequence]":
from fast_llm.layers.block.sequence import PatternBlockSequence
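For reference, a minimal standalone sketch of the `last_block_config` behaviour introduced above (plain dataclasses with hypothetical names, not the fast_llm config classes): the fixed sequence returns its single block config, while the pattern sequence indexes `blocks` by the last entry of `expanded_pattern`.

```python
# Standalone sketch of the last_block_config property added above.
from dataclasses import dataclass, field


@dataclass
class FixedSequenceSketch:
    block: str = "decoder_block"

    @property
    def last_block_config(self) -> str:
        return self.block


@dataclass
class PatternSequenceSketch:
    blocks: dict[str, str] = field(default_factory=dict)
    expanded_pattern: list[str] = field(default_factory=list)

    @property
    def last_block_config(self) -> str:
        # The last name in the expanded pattern selects the block config.
        return self.blocks[self.expanded_pattern[-1]]


assert FixedSequenceSketch().last_block_config == "decoder_block"
assert PatternSequenceSketch(
    blocks={"a": "attention_block", "m": "mamba_block"},
    expanded_pattern=["a", "m", "a"],
).last_block_config == "attention_block"
```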
16 changes: 14 additions & 2 deletions fast_llm/layers/block/sequence.py
@@ -24,6 +24,7 @@ def __init__(
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_last_layer_input: bool = False,
):
super().__init__(
config,
@@ -40,8 +41,13 @@ def __init__(
hidden_dim,
lr_scale=self._lr_scale,
peft=self._peft,
**(
{"return_input": True}
if return_last_layer_input and block_index == self._config.num_blocks - 1
else {}
),
)
for _ in range(self._config.num_blocks)
for block_index in range(self._config.num_blocks)
]
)

@@ -75,6 +81,7 @@ def __init__(
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_last_layer_input: bool = False,
):
super().__init__(
config,
@@ -90,8 +97,13 @@
hidden_dim,
lr_scale=self._lr_scale,
peft=self._peft,
**(
{"return_input": True}
if return_last_layer_input and block_index == self._config.num_blocks - 1
else {}
),
)
for name in self._config.expanded_pattern
for block_index, name in enumerate(self._config.expanded_pattern)
]
)

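Both constructors above use the same conditional-kwargs idiom: only the final block receives `return_input=True`, and only when `return_last_layer_input` is requested. A small self-contained sketch of that idiom (hypothetical helper, not library code):

```python
# Sketch of the conditional-kwargs idiom used in the block lists above.
def block_kwargs(num_blocks: int, return_last_layer_input: bool) -> list[dict]:
    return [
        (
            {"return_input": True}
            if return_last_layer_input and block_index == num_blocks - 1
            else {}
        )
        for block_index in range(num_blocks)
    ]


assert block_kwargs(3, True) == [{}, {}, {"return_input": True}]
assert block_kwargs(3, False) == [{}, {}, {}]
```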
131 changes: 30 additions & 101 deletions fast_llm/layers/language_model/config.py
@@ -1,4 +1,3 @@
import abc
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
@@ -14,7 +13,7 @@

if typing.TYPE_CHECKING:
from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.language_model.language_model import LanguageModel
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction

@@ -95,41 +94,8 @@ def layer_class(self) -> "type[LanguageModelEmbedding]":
return LanguageModelEmbedding


@config_class(registry=True)
class LanguageModelHeadBaseConfig(BlockConfig):
@classmethod
def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None:
# Default subclass.
return LanguageModelHeadConfig._from_dict(default, strict)
return super()._from_dict(default, strict=strict)

def get_layer(
self,
distributed_config: DistributedConfig,
embeddings_config: LanguageModelEmbeddingsConfig,
*,
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
) -> "LanguageModelHeadBase":
return self.layer_class(
self,
distributed_config,
embeddings_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
)

@property
@abc.abstractmethod
def max_prediction_distance(self) -> int:
pass


@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"})
class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
@config_class()
class LanguageModelHeadConfig(BlockConfig):
_abstract = False
normalization: NormalizationConfig = Field(
desc="Configuration for the final normalization layer.",
@@ -160,6 +126,18 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
prediction_loss_coefficient: list[float] | None = Field(
default=None,
desc="Loss coefficient for each prediction head.",
doc="If not provided, all heads are equally weighted.",
hint=FieldHint.feature,
)

def get_layer(
self,
@@ -169,85 +147,36 @@ def get_layer(
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
prediction_distance: int = 0,
prediction_heads: int = 1,
loss_coefficient: float = 1.0,
):
return self.layer_class(
block_config: DecoderBlockConfig | None = None,
) -> "tuple[LanguageModelHead, MultiTokenPrediction]":
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction

return LanguageModelHead(
self,
distributed_config,
embeddings_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
), MultiTokenPrediction(
self,
distributed_config,
embeddings_config,
hidden_dim=hidden_dim,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
peft=peft,
prediction_distance=prediction_distance,
prediction_heads=prediction_heads,
loss_coefficient=loss_coefficient,
block_config=block_config,
)

@property
def layer_class(self) -> "type[LanguageModelHead]":
from fast_llm.layers.language_model.head import LanguageModelHead

return LanguageModelHead

def _validate(self) -> None:
super()._validate()
assert LM_HEAD_LOSS_NAME not in self.losses

@property
def max_prediction_distance(self) -> int:
return 1

def get_reference_models(self) -> set[str]:
return {reference_model for loss in self.losses.values() for reference_model in loss.get_reference_models()}


@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"})
class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig):
_abstract = False
# Needs to be `DecoderBlockConfig` for the `return_input` interface.
# TODO: Make a generic wrapper for returning input instead?
block: DecoderBlockConfig = Field(
desc="Configuration for the decoder block before each head.",
hint=FieldHint.architecture,
)
# TODO: Generalize? (needs the extra initialization arguments)
head: LanguageModelHeadConfig = Field(
desc="Configuration for the multi-token-prediction heads.",
hint=FieldHint.architecture,
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
prediction_loss_coefficient: list[float] | None = Field(
default=None,
desc="Loss coefficient for each prediction head.",
doc="If not provided, all heads are equally weighted.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
super()._validate()
if isinstance(self.prediction_loss_coefficient, list):
Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
for coeff in self.prediction_loss_coefficient:
Assert.geq(coeff, 0)

@property
def layer_class(self) -> "type[MultiTokenPrediction]":
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction

return MultiTokenPrediction

@property
def max_prediction_distance(self) -> int:
return self.prediction_heads


@config_class()
class LanguageModelConfig(BlockConfig):
decoder: BlockSequenceConfig = Field(
@@ -258,7 +187,7 @@ class LanguageModelConfig(BlockConfig):
hint=FieldHint.architecture,
desc="Configuration for the language model embeddings.",
)
head: LanguageModelHeadBaseConfig = Field(
head: LanguageModelHeadConfig = Field(
hint=FieldHint.architecture, desc="Configuration for the language model head(s)."
)
tied_embedding_weight: bool = Field(
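The coefficient list is tied to the head count by the `_validate` check that appears in the diff above; a standalone sketch of that check (plain function with a hypothetical name):

```python
# Sketch of the per-head loss-coefficient check shown in the diff above: the
# list, if given, must have one non-negative entry per prediction head.
def validate_prediction_loss_coefficient(
    prediction_heads: int, prediction_loss_coefficient: list[float] | None
) -> None:
    if isinstance(prediction_loss_coefficient, list):
        assert len(prediction_loss_coefficient) == prediction_heads
        for coefficient in prediction_loss_coefficient:
            assert coefficient >= 0


validate_prediction_loss_coefficient(2, [1.0, 0.5])  # valid
validate_prediction_loss_coefficient(3, None)  # valid: heads weighted equally
```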
35 changes: 11 additions & 24 deletions fast_llm/layers/language_model/head.py
@@ -1,4 +1,3 @@
import abc
import functools
import logging
import typing
@@ -14,12 +13,11 @@
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward
from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
from fast_llm.layers.block.block import Block, BlockBase
from fast_llm.layers.block.block import Block
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.language_model.config import (
LM_HEAD_LOSS_NAME,
LanguageModelEmbeddingsConfig,
LanguageModelHeadBaseConfig,
LanguageModelHeadConfig,
LanguageModelKwargs,
)
@@ -32,15 +30,7 @@
OUTPUT_WEIGHTS = "output_weights"


class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](BlockBase[ConfigType]):
heads: "list[LanguageModelHead]"

@abc.abstractmethod
def get_output_weights(self) -> list[torch.Tensor]:
pass


class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType], Block):
class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]):
"""
A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable).
TODO: Cleanup (dynamic type? composition?)
@@ -58,7 +48,6 @@ def __init__(
lr_scale: float | None,
peft: PeftConfig | None,
prediction_distance: int = 0,
prediction_heads: int = 1,
loss_coefficient: float = 1.0,
):
super().__init__(
@@ -68,11 +57,9 @@ def __init__(
lr_scale=lr_scale,
peft=peft,
)
Assert.in_range(prediction_distance, 0, prediction_heads)
Assert.in_range(prediction_distance, 0, self._config.prediction_heads)
self._prediction_distance = prediction_distance
self._prediction_heads = prediction_heads
self._loss_coefficient = loss_coefficient
self._is_last_head = self._prediction_distance == self._prediction_heads - 1
self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1

self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel
self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
@@ -99,17 +86,22 @@ def __init__(
loss_configs = (
self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()}
)
loss_coefficient = (
1.0
if self._config.prediction_loss_coefficient is None
else self._config.prediction_loss_coefficient[self._prediction_distance]
)
self.losses = torch.nn.ModuleList(
[
loss_config.get_layer(
distributed_config,
self._get_full_loss_name(name),
self._prediction_distance,
self._prediction_heads,
self._config.prediction_heads,
self._vocab_parallel,
self._config.cross_entropy_splits,
self._config.logits_scale_factor,
self._loss_coefficient,
loss_coefficient,
)
for name, loss_config in loss_configs.items()
]
@@ -305,8 +297,3 @@ def _get_full_loss_name(self, name) -> str:
@functools.cached_property
def _total_loss_name(self) -> str:
return self._get_full_loss_name(LM_HEAD_LOSS_NAME)

@property
def heads(self) -> "list[LanguageModelHead]":
# For compatibility with MTP.
return [self]
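In head.py above, each head now derives its loss coefficient from the shared config (indexed by its prediction distance) instead of a constructor argument. A one-function sketch of that lookup, with hypothetical names:

```python
# Sketch of the coefficient lookup in LanguageModelHead.__init__ above: default
# to 1.0 when no list is configured, otherwise index by prediction distance.
def head_loss_coefficient(
    prediction_loss_coefficient: list[float] | None, prediction_distance: int
) -> float:
    if prediction_loss_coefficient is None:
        return 1.0
    return prediction_loss_coefficient[prediction_distance]


assert head_loss_coefficient(None, 0) == 1.0
assert head_loss_coefficient([1.0, 0.5], 1) == 0.5
```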