1 change: 0 additions & 1 deletion examples/mistral.yaml
@@ -8,7 +8,6 @@ training:
evaluator:
type: loss
iterations: null
test_iters: 0
batch:
sequence_length: 4096
micro_batch_size: 2
10 changes: 8 additions & 2 deletions fast_llm/data/auto.py
@@ -6,14 +6,20 @@
BlendedDatasetConfig,
ConcatenatedDatasetConfig,
DatasetSliceConfig,
MemmapDatasetConfig,
SampledDatasetUpdateConfig,
)
from fast_llm.data.dataset.memmap.config import ( # isort: skip
LanguageModelReaderConfig,
MemmapDatasetConfig,
NullReaderConfig,
PatchReaderConfig,
RangeReaderConfig,
TokenReaderConfig,
)
from fast_llm.data.dataset.gpt.config import ( # isort: skip
GPTDatasetFromFileConfig,
GPTFimSampledDatasetConfig,
GPTRandomDatasetConfig,
)
from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig # isort: skip
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip
from fast_llm.data.sample.abstract import NullReaderConfig # isort: skip
File renamed without changes.
124 changes: 124 additions & 0 deletions fast_llm/data/batch/config.py
@@ -0,0 +1,124 @@
import abc
import dataclasses
import functools
import logging
import typing

from fast_llm.config import Configurable, Field, FieldUpdate, config_class
from fast_llm.data.document.abstract import Batch, Document
from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig
from fast_llm.data.preprocessing.image_patch import ImagePatchConfig
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.preprocessing.tokenizer import TokenizerConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.models.gpt.config import GPTBatchConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
import torch

logger = logging.getLogger(__name__)


@config_class()
class BatchPreprocessingConfig(PreprocessingConfig):
batch: BatchConfig = Field()
phase: PhaseType = Field(default=PhaseType.inference)

def get_batch_meta(self) -> "PreprocessedBatch":
raise NotImplementedError()


@config_class()
class LanguageModelBatchPreprocessingConfig(LanguageModelPreprocessingConfig, BatchPreprocessingConfig):
_abstract = False
# TODO: Duplicate `use_loss_masking_spans`, `use_preference_spans`
batch: GPTBatchConfig = FieldUpdate()
predicted_tokens: int = Field(default=1)
return_cumulative_sequence_lengths: bool = Field(default=False)
return_max_sequence_lengths: bool = Field(default=False)
return_document_index: bool = Field(default=False)
return_position_index: bool = Field(default=False)
return_prediction_mask: bool = Field(default=False)

def _validate(self) -> None:
super()._validate()
Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig))
Assert.custom(isinstance, self.tokenizer, (TokenizerConfig, NullPreprocessingConfig))

def get_batch_meta(self) -> "PreprocessedBatch":
import torch

from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch
from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
from fast_llm.data.document.token import TokenDocument

device = torch.device("meta")
tokens = torch.empty(self.total_length, dtype=torch.int64, device=device)
batch = LanguageModelBatch.from_documents([LanguageModelDocument(tokens=TokenDocument(tokens=tokens))])
return LanguageModelPreprocessedBatch.from_batch(batch, config=self, device=device)

@functools.cached_property
def use_image_patches(self) -> bool:
return isinstance(self.image_patches, ImagePatchConfig)

@functools.cached_property
def total_length(self) -> int:
return self.batch.sequence_length + self.predicted_tokens

@functools.cached_property
def distributed(self) -> DistributedConfig:
return self.batch.distributed

def check_compatibility(self, preprocessing: typing.Self) -> None:
Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig)
# TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub?
if self.vocab_size is not None and preprocessing.vocab_size is not None:
Assert.leq(self.vocab_size, preprocessing.vocab_size)
if preprocessing.use_preference_spans:
# Preference spans are strictly needed for DPO loss.
assert self.use_preference_spans, "The dataset is missing required preference spans"
if preprocessing.use_image_patches and self.use_image_patches:
self.image_patches.check_compatibility(preprocessing.image_patches)


@dataclasses.dataclass
class MicroBatch:
pass


class PreprocessedBatch[ConfigType: BatchPreprocessingConfig, MicroBatchType: MicroBatch](Configurable[ConfigType]):
def __init__(self, config: ConfigType, micro_batches: list[MicroBatchType]):
super().__init__(config)
self._micro_batches = micro_batches

@property
def micro_batches(self) -> list[MicroBatchType]:
return self._micro_batches

def __len__(self) -> int:
return len(self._micro_batches)

def __getitem__(self, idx: int) -> MicroBatchType:
return self._micro_batches[idx]

@classmethod
@abc.abstractmethod
def from_documents(
cls,
documents: list[Document],
config: BatchPreprocessingConfig,
device: "torch.device | None" = None,
) -> typing.Self:
pass

@classmethod
@abc.abstractmethod
def from_batch(
cls,
batch: Batch,
config: BatchPreprocessingConfig,
device: "torch.device | None" = None,
) -> typing.Self:
pass
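
Not part of the patch: a minimal, self-contained sketch of the relationship encoded by total_length above (sequence_length + predicted_tokens), assuming the usual shifted-label convention for multi-token prediction. All names and values below are illustrative.

sequence_length, predicted_tokens = 8, 2
total_length = sequence_length + predicted_tokens  # tokens loaded per sample

tokens = list(range(total_length))   # stand-in for one packed token sample
inputs = tokens[:sequence_length]    # positions the model actually consumes
# Labels for predicting d tokens ahead are the window shifted right by d,
# so covering distances 1..predicted_tokens needs predicted_tokens extra tokens.
labels_per_distance = {d: tokens[d : sequence_length + d] for d in range(1, predicted_tokens + 1)}
assert all(len(labels) == sequence_length for labels in labels_per_distance.values())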
159 changes: 159 additions & 0 deletions fast_llm/data/batch/language_model.py
@@ -0,0 +1,159 @@
import dataclasses
import typing

import torch

from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig, MicroBatch, PreprocessedBatch
from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedDimNames
from fast_llm.tensor import TensorMeta


@dataclasses.dataclass
class LanguageModelMicroBatch(MicroBatch):
tokens: torch.Tensor
token_dim: TensorDim
hidden_token_dim: TensorDim
sequence_k_dim: TensorDim
# TODO: Adjust names
num_tokens: int # Number of tokens in the micro-batch excluding padding at the end.
sequence_length: int # Total number of tokens across all micro-batches, including padding.
document_lengths: list[int]
is_meta: bool
labels: list[torch.Tensor] = dataclasses.field(default_factory=list)
prediction_masks: list[torch.Tensor] = dataclasses.field(default_factory=list)
cumulative_lengths_q: torch.Tensor | None = None
cumulative_lengths_k: torch.Tensor | None = None
max_length_q: torch.Tensor | None = None
max_length_k: torch.Tensor | None = None
document_index: torch.Tensor | None = None
position_index: torch.Tensor | None = None
# TODO: ====== Preference spans? ======

def to_device_(self, device: torch.device):
self.tokens = self.tokens.to(device, non_blocking=True)
if self.cumulative_lengths_q is not None:
self.cumulative_lengths_q = self.cumulative_lengths_q.to(device, non_blocking=True)
if self.cumulative_lengths_k is not None:
self.cumulative_lengths_k = self.cumulative_lengths_k.to(device, non_blocking=True)
if self.max_length_q is not None:
self.max_length_q = self.max_length_q.to(device, non_blocking=True)
if self.max_length_k is not None:
self.max_length_k = self.max_length_k.to(device, non_blocking=True)
if self.document_index is not None:
self.document_index = self.document_index.to(device, non_blocking=True)
if self.position_index is not None:
self.position_index = self.position_index.to(device, non_blocking=True)


@dataclasses.dataclass
class LanguageModelPreprocessedBatch[
ConfigType: LanguageModelBatchPreprocessingConfig, MicroBatchType: LanguageModelMicroBatch
](PreprocessedBatch[ConfigType, MicroBatchType]):
def __init__(self, config: LanguageModelBatchPreprocessingConfig, micro_batches: list[MicroBatchType]):
super().__init__(config, micro_batches)

@classmethod
def from_documents(
cls,
documents: list[LanguageModelDocument],
config: ConfigType,
device: torch.device | None = None,
) -> typing.Self:
batch = LanguageModelBatch.from_documents(
documents, pad_to_size=config.batch.micro_batch_size * config.total_length
)
return cls.from_batch(batch, config=config, device=device)

@classmethod
def from_batch(
cls,
batch: LanguageModelBatch,
config: ConfigType,
device: torch.device | None = None,
) -> typing.Self:
if device is None:
device = batch.tokens.tokens.device
batch.to_device_(device)
is_meta = device.type == "meta"

token_dim = TensorDim(
"token",
config.batch.micro_sequence_length,
config.distributed.get_distributed_dim(DistributedDimNames.sequence_data),
)
hidden_token_dim = (
TensorDim(
"token_tp",
token_dim.global_size,
config.distributed.get_distributed_dim(DistributedDimNames.tensor_and_data),
)
if config.distributed.sequence_tensor_parallel
else token_dim
)
micro_batches = []
for micro_sequence_index, sequence_k_past in enumerate(
range(
token_dim.size * config.distributed.sequence_data_rank,
config.batch.sequence_length,
token_dim.global_size,
)
):
sequence_k = sequence_k_past + token_dim.size
sequence_k_dim = TensorDim("sequence_k", sequence_k)
cropped_sample = batch.crop(sequence_k_past, sequence_k)
if is_meta:
tokens = TensorMeta.from_dims(
(token_dim,), tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64
)
else:
tokens = batch.tokens.tokens[sequence_k_past:sequence_k]
micro_batch = LanguageModelMicroBatch(
tokens=tokens,
token_dim=token_dim,
hidden_token_dim=hidden_token_dim,
sequence_k_dim=sequence_k_dim,
num_tokens=min(sequence_k, batch.num_tokens) - sequence_k_past,
sequence_length=config.batch.sequence_length,
document_lengths=batch.tokens.lengths,
is_meta=is_meta,
)
if not is_meta:
if config.return_cumulative_sequence_lengths:
micro_batch.cumulative_lengths_q, micro_batch.cumulative_lengths_k = (
cropped_sample.tokens.get_cumulative_lengths(device)
)
if config.return_max_sequence_lengths:
micro_batch.max_length_q, micro_batch.max_length_k = cropped_sample.tokens.get_max_lengths(device)
if config.return_document_index:
micro_batch.document_index = cropped_sample.tokens.get_document_index()
if config.return_position_index:
micro_batch.position_index = cropped_sample.tokens.get_position_index()

for prediction_distance in range(1, config.predicted_tokens + 1):
label_begin = sequence_k_past + prediction_distance
label_end = sequence_k + prediction_distance
label_tokens = batch.tokens.crop(label_begin, label_end)
labels = label_tokens.tokens.clone()

# Apply loss masking spans.
if config.use_loss_masking_spans and batch.loss_masking_spans is not None:
for span_begin, span_end in batch.loss_masking_spans.crop(label_begin, label_end).ranges:
labels[span_begin:span_end] = -100

# Mask cross-document predictions.
document_begin = label_tokens.lengths[0]
for length in label_tokens.lengths[1:]:
labels[document_begin : document_begin + prediction_distance] = -100
document_begin += length

# Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
micro_batch.labels.append(labels)
if config.return_prediction_mask:
# TODO: Does the prediction mask really need all sources of masking?
# (i.e. lack of labels doesn't mean we can't do predictions and compute other losses.)
micro_batch.prediction_masks.append(labels > 0)

micro_batches.append(micro_batch)
return cls(micro_batches=micro_batches, config=config)
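
Not part of the patch: a standalone illustration of the cross-document label masking performed in from_batch above, assuming documents are packed back to back and -100 is the ignore index for labels. The helper name is hypothetical.

import torch


def mask_cross_document_labels(labels: torch.Tensor, lengths: list[int], prediction_distance: int) -> torch.Tensor:
    # The first prediction_distance label positions of each document after the
    # first would require predicting across a document boundary, so mask them.
    labels = labels.clone()
    document_begin = lengths[0]
    for length in lengths[1:]:
        labels[document_begin : document_begin + prediction_distance] = -100
        document_begin += length
    return labels


# Two documents of lengths 4 and 3 packed into one 7-token label window:
print(mask_cross_document_labels(torch.arange(10, 17), lengths=[4, 3], prediction_distance=1))
# tensor([  10,   11,   12,   13, -100,   15,   16])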
36 changes: 16 additions & 20 deletions fast_llm/data/data/abstract.py
@@ -3,10 +3,8 @@
import typing

from fast_llm.config import Configurable
from fast_llm.data.batch.config import BatchPreprocessingConfig, PreprocessedBatch
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
from fast_llm.data.sample.abstract import Batch
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.schedule.config import BatchConfig

@@ -16,31 +14,28 @@

class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
_distributed: "Distributed"
_sampling_parameters: dict[str, SamplingParameters]
_preprocessing: PreprocessingConfig
# _sampling_parameters: dict[str, SamplingParameters]
# _preprocessing: dict[str, PreprocessingConfig]
_cache_directory: pathlib.Path | None
_is_setup: bool = False

def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
super().__init__(config)
self._distributed_config = distributed_config

# TODO: Improve interface
def setup(
self,
distributed: "Distributed",
sampling_parameters: dict[str, SamplingParameters],
preprocessing: PreprocessingConfig,
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._sampling_parameters = sampling_parameters
self._preprocessing = preprocessing
def setup(self, cache_directory: pathlib.Path) -> None:
self._cache_directory = cache_directory
self._is_setup = True

@property
def distributed(self):
return self._distributed
@abc.abstractmethod
def sample_dataset(
self,
dataset_name: str,
config: BatchPreprocessingConfig,
num_samples: int,
) -> None:
pass

@abc.abstractmethod
def get_iterator(
@@ -52,5 +47,6 @@ def get_iterator(
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[Batch]:
preprocess: bool = True,
) -> typing.Iterator[PreprocessedBatch]:
pass
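
Not part of the patch: a hedged sketch of how a caller might drive the reworked Data interface in this hunk, where setup only records the cache directory and datasets are sampled on demand with a batch preprocessing config. The function name and argument values are illustrative, and the parameters of get_iterator not shown in this diff are left out.

import pathlib

from fast_llm.data.batch.config import BatchPreprocessingConfig
from fast_llm.data.data.abstract import Data


def warm_up_training_data(data: Data, preprocessing: BatchPreprocessingConfig, cache_dir: pathlib.Path) -> None:
    data.setup(cache_directory=cache_dir)
    data.sample_dataset("training", config=preprocessing, num_samples=10_000)
    # Batches would then be consumed through get_iterator(...), which now yields
    # PreprocessedBatch objects rather than raw Batch objects.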
4 changes: 2 additions & 2 deletions fast_llm/data/data/gpt/config.py
@@ -8,7 +8,7 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.document.language_model import LanguageModelDocument
logger = logging.getLogger(__name__)


@@ -23,7 +23,7 @@ class GPTDataConfig(DataConfig):
_abstract = False

# TODO: Review field. Move closer to phase definition in training config?
datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field(
datasets: dict[str, SampledDatasetConfig["LanguageModelDocument"]] = Field(
default_factory=dict,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,