From 7469f83eb0d26875e92ffaa52791ee2a20dd9698 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 17 Feb 2026 16:50:38 -0500 Subject: [PATCH 1/4] stuff --- fast_llm/batch/__init__.py | 0 fast_llm/batch/config.py | 54 ++++ fast_llm/batch/language_model.py | 144 +++++++++++ fast_llm/data/data/abstract.py | 6 +- fast_llm/data/data/gpt/data.py | 5 +- fast_llm/data/preprocessing/language_model.py | 2 +- fast_llm/data/sample/abstract.py | 3 + fast_llm/data/sample/language_model.py | 11 + fast_llm/data/sample/patch.py | 5 + fast_llm/data/sample/token.py | 62 ++++- fast_llm/engine/base_model/base_model.py | 9 + fast_llm/engine/training/trainer.py | 31 ++- fast_llm/layers/attention/attention.py | 190 ++++++-------- fast_llm/layers/attention/config.py | 6 - fast_llm/layers/block/config.py | 1 + fast_llm/layers/block/sequence.py | 12 +- fast_llm/layers/decoder/block.py | 7 +- fast_llm/layers/language_model/config.py | 7 - fast_llm/layers/language_model/embedding.py | 20 +- .../layers/language_model/language_model.py | 11 +- .../language_model/multi_token_prediction.py | 6 +- fast_llm/layers/ssm/gdn.py | 12 +- fast_llm/layers/ssm/kda.py | 12 +- fast_llm/layers/ssm/mamba.py | 14 +- fast_llm/layers/vision/vision_encoder.py | 17 +- fast_llm/models/gpt/model.py | 234 +++--------------- fast_llm/models/gpt/trainer.py | 14 +- .../models/multimodal/conversion/apriel2.py | 3 +- .../models/multimodal/conversion/llava.py | 2 - fast_llm/models/multimodal/model.py | 3 +- tests/layers/test_attention.py | 4 +- tests/layers/test_lm_head.py | 3 - tests/layers/test_varlen.py | 4 +- tests/utils/model_configs.py | 3 - 34 files changed, 502 insertions(+), 415 deletions(-) create mode 100644 fast_llm/batch/__init__.py create mode 100644 fast_llm/batch/config.py create mode 100644 fast_llm/batch/language_model.py diff --git a/fast_llm/batch/__init__.py b/fast_llm/batch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/batch/config.py b/fast_llm/batch/config.py new file mode 100644 index 000000000..f857d115b --- /dev/null +++ b/fast_llm/batch/config.py @@ -0,0 +1,54 @@ +import functools +import logging +import typing + +from fast_llm.config import Field, FieldUpdate, config_class +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +@config_class(registry=True) +class BatchPreprocessingConfig(PreprocessingConfig): + batch: BatchConfig = Field() + + +@config_class(dynamic_type={PreprocessingConfig: "language_model"}) +class LanguageModelBatchPreprocessingConfig(LanguageModelPreprocessingConfig): + _abstract = False + # TODO: Duplicate `use_loss_masking_spans`, `use_preference_spans` + batch: GPTBatchConfig = FieldUpdate() + phase: PhaseType = Field(default=PhaseType.inference) + predicted_tokens: int = Field(default=1) + return_cumulative_sequence_lengths: bool = Field(default=False) + return_max_sequence_lengths: bool = Field(default=False) + return_document_index: bool = Field(default=False) + return_position_index: bool = Field(default=False) + return_prediction_mask: bool = 
Field(default=False) + + def _validate(self) -> None: + super()._validate() + Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig)) + Assert.custom(isinstance, self.tokenizer, (TokenizerConfig, NullPreprocessingConfig)) + + @functools.cached_property + def use_image_patches(self) -> bool: + return isinstance(self.image_patches, ImagePatchConfig) + + def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) + # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? + if self.vocab_size is not None and preprocessing.vocab_size is not None: + Assert.leq(self.vocab_size, preprocessing.vocab_size) + if preprocessing.use_preference_spans: + # Preference spans are strictly needed for DPO loss. + assert self.use_preference_spans, "The dataset is missing required preference spans" + if preprocessing.use_image_patches and self.use_image_patches: + self.image_patches.check_compatibility(preprocessing.image_patches) diff --git a/fast_llm/batch/language_model.py b/fast_llm/batch/language_model.py new file mode 100644 index 000000000..7de5c07e3 --- /dev/null +++ b/fast_llm/batch/language_model.py @@ -0,0 +1,144 @@ +import dataclasses +import typing + +import torch + +from fast_llm.batch.config import LanguageModelBatchPreprocessingConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames + + +@dataclasses.dataclass +class LanguageModelBatchNew: + tokens: torch.Tensor + token_dim: TensorDim + hidden_token_dim: TensorDim + sequence_k_dim: TensorDim + # TODO: Adjust names + num_tokens: int # Number of tokens in the micro-batch excluding padding at the end. + sequence_length: int # Total number of tokens across all micro-batches, including padding. 
+ document_lengths: list[int] + labels: list[torch.Tensor] = dataclasses.field(default_factory=list) + prediction_masks: list[torch.Tensor] = dataclasses.field(default_factory=list) + cumulative_lengths_q: torch.Tensor | None = None + cumulative_lengths_k: torch.Tensor | None = None + max_length_q: torch.Tensor | None = None + max_length_k: torch.Tensor | None = None + document_index: torch.Tensor | None = None + position_index: torch.Tensor | None = None + chosen_spans: list[tuple[int, int]] | None = None + rejected_spans: list[tuple[int, int]] | None = None + + def to_device_(self, device: torch.device): + self.tokens = self.tokens.to(device, non_blocking=True) + if self.cumulative_lengths_q is not None: + self.cumulative_lengths_q = self.cumulative_lengths_q.to(device, non_blocking=True) + if self.cumulative_lengths_k is not None: + self.cumulative_lengths_k = self.cumulative_lengths_k.to(device, non_blocking=True) + if self.max_length_q is not None: + self.max_length_q = self.max_length_q.to(device, non_blocking=True) + if self.max_length_k is not None: + self.max_length_k = self.max_length_k.to(device, non_blocking=True) + if self.document_index is not None: + self.document_index = self.document_index.to(device, non_blocking=True) + if self.position_index is not None: + self.position_index = self.position_index.to(device, non_blocking=True) + + @classmethod + def from_documents( + cls, + config: LanguageModelBatchPreprocessingConfig, + distributed_config: DistributedConfig, + documents: list[LanguageModelSample], + device: torch.device | None = None, + ) -> list[typing.Self]: + num_tokens = sum(len(document) for document in documents) + padding = config.batch.sequence_length + config.predicted_tokens - num_tokens + sample = LanguageModelSample.from_documents(documents + [documents[0].get_padding(padding)]) + # sample.tokens.lengths + # lengths = [len(document) for document in documents] + # num_tokens = sum(lengths) + + if device is None: + device = sample.tokens.tokens.device + sample.to_device_(device) + + token_dim = TensorDim( + "token", + config.batch.micro_sequence_length, + distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), + ) + hidden_token_dim = ( + ( + "token_tp", + token_dim.global_size, + distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), + ) + if distributed_config.sequence_tensor_parallel + else token_dim + ) + micro_batches = [] + for micro_sequence_index, sequence_k_past in enumerate( + range( + token_dim.size * distributed_config.sequence_data_rank, + config.batch.sequence_length, + token_dim.global_size, + ) + ): + sequence_k = sequence_k_past + token_dim.size + sequence_k_dim = TensorDim("sequence_k", sequence_k) + cropped_sample = sample.crop(sequence_k_past, sequence_k) + + # document_lengths, cumulative_lengths_q, cumulative_lengths_k, first_document_index, remaining_tokens = crop_lengths( + # sample.tokens.lengths, sequence_k_past, sequence_k_past + token_dim.size) + + micro_batch = LanguageModelBatchNew( + tokens=sample.tokens.tokens[sequence_k_past:sequence_k], + token_dim=token_dim, + hidden_token_dim=hidden_token_dim, + sequence_k_dim=sequence_k_dim, + num_tokens=min(sequence_k, num_tokens) - sequence_k_past, + sequence_length=config.batch.sequence_length, + document_lengths=sample.tokens.lengths, + ) + if config.return_cumulative_sequence_lengths: + micro_batch.cumulative_lengths_q, micro_batch.cumulative_lengths_k = ( + cropped_sample.tokens.get_cumulative_lengths(device) + ) + if 
config.return_max_sequence_lengths: + micro_batch.max_length_q, micro_batch.max_length_k = cropped_sample.tokens.get_max_lengths(device) + if config.return_document_index: + micro_batch.document_index = cropped_sample.tokens.get_document_index() + if config.return_position_index: + micro_batch.position_index = cropped_sample.tokens.get_position_index() + if config.use_preference_spans: + micro_batch.chosen_spans = cropped_sample.chosen_spans.ranges + micro_batch.rejected_spans = cropped_sample.rejected_spans.ranges + + for prediction_distance in range(1, config.predicted_tokens + 1): + label_begin = sequence_k_past + prediction_distance + label_end = sequence_k + prediction_distance + label_tokens = sample.tokens.crop(label_begin, label_end) + labels = label_tokens.tokens.clone() + + # Apply loss masking spans. + if config.use_loss_masking_spans: + for span_begin, span_end in sample.loss_masking_spans.crop(label_begin, label_end).ranges: + labels[span_begin:span_end] = -100 + + # Mask cross-document predictions. + document_end = 0 + for length in label_tokens.lengths: + document_end += length + labels[max(document_end - prediction_distance, 0) : document_end] = -100 + + # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions. + micro_batch.labels.append(labels) + if config.return_prediction_mask: + # TODO: Does the prediction mask really need all sources of masking? + # (i.e. lack of labels doesn't mean we can't do predictions and compute other losses.) + micro_batch.prediction_masks.append(labels > 0) + + micro_batches.append(micro_batch) + return micro_batches diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 2c1902796..e01331be2 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -9,6 +9,7 @@ from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.distributed.distributed import Distributed @@ -17,7 +18,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] - _preprocessing: PreprocessingConfig + _preprocessing: dict[str, PreprocessingConfig] _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: @@ -29,10 +30,11 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, SamplingParameters], - preprocessing: PreprocessingConfig, + preprocessing: dict[str, PreprocessingConfig], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: + Assert.eq(sampling_parameters.keys(), preprocessing.keys()) self._distributed = distributed self._sampling_parameters = sampling_parameters self._preprocessing = preprocessing diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 17f151919..3a1e99e6d 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,6 +32,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, SamplingParameters] + _preprocessing: dict[str, LanguageModelPreprocessingConfig] _is_setup: bool = False def __init__( @@ -49,7 +50,7 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, SamplingParameters], - 
preprocessing: LanguageModelPreprocessingConfig, + preprocessing: dict[str, LanguageModelPreprocessingConfig], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: @@ -84,7 +85,7 @@ def setup( sampling = GPTSamplingData( config=self._config.sampling, parameters=sampling_parameters, - preprocessing=preprocessing, + preprocessing=self._preprocessing[dataset_name], cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index d54776eec..b4f1a69a7 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -20,7 +20,7 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig): # so we provide the vocab size and use it for compatibility checks. image_patches: PreprocessingConfig = Field() vocab_size: int | None = Field(default=None) - use_loss_masking_spans: bool = Field(default=False) + use_loss_masking_spans: bool = Field(default=True) use_preference_spans: bool = Field(default=False) def _validate(self) -> None: diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 494a5c4a5..c5dcf165e 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -29,6 +29,9 @@ def __len__(self) -> int: def get_padding(self, size: int) -> typing.Self: pass + def to_device_(self, device: "torch.device | str"): + pass + class Batch(abc.ABC): # TODO: Relate to `BatchConfig`? diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 22b89acf1..db7e89d87 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -91,6 +91,17 @@ def get_padding(self, size: int) -> typing.Self: None if self.image_patches is None else self.image_patches.get_padding(size), ) + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) + if self.image_patches is not None: + self.image_patches.to_device_(device) + class LanguageModelBatch(Batch): def __init__( diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 32ea60cb8..0be91f0c8 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -93,6 +93,11 @@ def get_padding(self, size: int) -> typing.Self: [], ) + def to_device_(self, device: "torch.device | str"): + self.patches = self.patches.to(device, non_blocking=True) + self.token_map = self.token_map.to(device, non_blocking=True) + self.positions = self.positions.to(device, non_blocking=True) + class PatchBatch(Batch): def __init__( diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 6ab55dbba..17078cef9 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -14,7 +14,7 @@ Sample, ) from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import Assert, get_unique +from fast_llm.utils import Assert, get_unique, padded_cumsum def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]: @@ -35,7 +35,13 @@ def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]: class TokenSample(Sample): - def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = 
None):
+    def __init__(
+        self,
+        tokens: torch.Tensor,
+        lengths: list[int] | None = None,
+        sequence_k_past: int = 0,
+        current_document_begin: int = 0,
+    ):
         self.tokens = tokens
         # Length of each document in the sample. TODO: Use cumsums instead?
         if lengths is None:
@@ -43,6 +49,8 @@ def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None):
         else:
             Assert.eq(sum(lengths), len(tokens))
         self.lengths = lengths
+        self.sequence_k_past = sequence_k_past
+        self.current_document_begin = current_document_begin
 
     @classmethod
     def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
@@ -52,7 +60,23 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
         )
 
     def crop(self, begin: int, end: int) -> typing.Self:
-        return self.__class__(self.tokens[begin:end], crop_lengths(self.lengths, begin, end))
+        Assert.eq(self.sequence_k_past, self.current_document_begin, 0)
+
+        document_begin = 0
+        lengths_ = []
+        current_document_begin = None
+        for length in self.lengths:
+            document_end = document_begin + length
+            cropped_length = min(document_end, end) - max(document_begin, begin)
+            if cropped_length > 0:
+                lengths_.append(cropped_length)
+                if current_document_begin is None:
+                    current_document_begin = document_begin
+            if document_end > end:
+                break
+            document_begin = document_end
+
+        return self.__class__(self.tokens[begin:end], lengths_, begin, current_document_begin)
 
     def __len__(self) -> int:
         return len(self.tokens)
@@ -60,6 +84,38 @@ def __len__(self) -> int:
     def get_padding(self, size: int) -> typing.Self:
         return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size])
 
+    def to_device_(self, device: "torch.device | str"):
+        # Also standardize the dtype while we're here.
+        self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True)
+
+    def get_cumulative_lengths(self, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
+        cumulative_lengths_q = torch.from_numpy(padded_cumsum(self.lengths)).to(dtype=torch.int32, device=device)
+        cumulative_lengths_k = torch.cat(
+            [cumulative_lengths_q[:1] + self.current_document_begin, cumulative_lengths_q[1:] + self.sequence_k_past]
+        )
+        return cumulative_lengths_q, cumulative_lengths_k
+
+    def get_max_lengths(self, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
+        max_length_q = max(self.lengths)
+        max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.current_document_begin)
+        return (
+            torch.full((1,), max_length_q, dtype=torch.int32, device=device),
+            torch.full((1,), max_length_k, dtype=torch.int32, device=device),
+        )
+
+    def get_document_index(self, device: torch.device | None = None) -> torch.Tensor:
+        return torch.cat(
+            [
+                torch.full((document_length,), i, dtype=torch.int32, device=device)
+                for i, document_length in enumerate(self.lengths)
+            ]
+        )
+
+    def get_position_index(self, device: torch.device | None = None) -> torch.Tensor:
+        return torch.cat(
+            [torch.arange(document_length, dtype=torch.int32, device=device) for document_length in self.lengths]
+        )
+
 
 class TokenBatch(Batch):
     def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None:
diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index de64d905a..f5f8dc5e7 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -9,6 +9,7 @@
 from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
 from fast_llm.engine.distributed.distributed 
import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.utils import safe_merge_dicts if typing.TYPE_CHECKING: from fast_llm.engine.inference.runner import InferenceRunner @@ -53,6 +54,11 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: losses += layer.get_loss_definitions(count) return losses + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return safe_merge_dicts( + *(layer.get_preprocessing_config(phase) for layer in self.get_layers() if layer is not self) + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: for layer in self.get_layers(): if layer is not self: @@ -107,6 +113,9 @@ def get_layers(self) -> list["Layer"]: """ return self._layers_with_namespace + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return self._layer.get_preprocessing_config(phase) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: """ Preprocess with namespace. diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index b35733cc7..68c73bf70 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -229,20 +229,23 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._runner.setup(distributed, self._optimizer) # Setup the datasets. log_main_rank("Preparing datasets...") + sampling_parameters = {} + preprocessing_configs = {} + for phase, datasets in self._samples_per_split.items(): + for dataset_name, samples in datasets.items(): + sampling_parameters[dataset_name] = self._get_sampling_parameters({"num_samples": samples}) + preprocessing_configs[dataset_name] = self._get_preprocessing_config(phase) + for eval_sampling_params in self._evaluator_runner.get_sampling_parameters(): + sampling_parameters[eval_sampling_params.dataset_name] = self._get_sampling_parameters( + {"num_samples": eval_sampling_params.num_samples} + ) + preprocessing_configs[eval_sampling_params.dataset_name] = self._get_preprocessing_config( + PhaseType.inference + ) self._data.setup( distributed, - { - dataset_name: self._get_sampling_parameters({"num_samples": samples}) - for datasets in self._samples_per_split.values() - for dataset_name, samples in datasets.items() - } - | { - eval_sampling_params.dataset_name: self._get_sampling_parameters( - {"num_samples": eval_sampling_params.num_samples} - ) - for eval_sampling_params in self._evaluator_runner.get_sampling_parameters() - }, - self._get_preprocessing_config(), + sampling_parameters, + preprocessing_configs, None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, ) @@ -269,7 +272,9 @@ def _get_sampling_parameters( ) -> SamplingParameters | dict[str, typing.Any]: return parameters if _return_dict else SamplingParameters(**parameters) - def _get_preprocessing_config(self, *, _return_dict: bool = False) -> PreprocessingConfig | dict[str, typing.Any]: + def _get_preprocessing_config( + self, phase: PhaseType, *, _return_dict: bool = False + ) -> PreprocessingConfig | dict[str, typing.Any]: return {} if _return_dict else NullPreprocessingConfig() @property diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 859bafea2..0eaae34f7 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -8,7 +8,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization 
import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen @@ -172,30 +172,28 @@ def __init__( def _attn_backup( self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + query: torch.Tensor, # sq, head_per_group * head_group, head_size + key: torch.Tensor, # sk, head_group, head_size + value: torch.Tensor, # sk, head_group, head_size kwargs: dict[str, typing.Any], - ) -> torch.Tensor: + ) -> torch.Tensor: # sq, head_per_group * head_group, head_size # Backup attention (inefficient) - b, sq, _, _ = query.shape - sk = key.size(1) - - if self._local_head_groups == 1: - query = query.view(b, sq * self._local_heads, self._config.head_size) - key = key.flatten(-2).transpose(-1, -2) - value = value.flatten(-2) - else: - query = ( - query.unflatten(2, (self._local_head_groups, self._local_heads_per_group)) - .transpose(1, 2) - .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.head_size) - ) - key = key.movedim(1, 3).flatten(0, 1) - value = value.transpose(1, 2).flatten(0, 1) + sq = query.size(0) + sk = key.size(0) + + # sq, head_per_group * head_group, head_size -> head_group, sq * head_per_group, head_size + query = ( + query.unflatten(1, (self._local_head_groups, self._local_heads_per_group)) + .transpose(0, 1) + .view(self._local_head_groups, sq * self._local_heads_per_group, self._config.head_size) + ) + # sk, head_group, head_size -> head_group, head_size, sk + key = key.movedim(0, 2) + # sk, head_group, head_size -> head_group, sk, head_size + value = value.transpose(0, 1) attn_weights = torch.empty( - (b * self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype + (self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype ) attn_weights = torch.baddbmm( attn_weights, @@ -203,7 +201,7 @@ def _attn_backup( key, beta=0, alpha=self._softmax_scale, - ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) + ).view(self._local_head_groups, sq, self._local_heads_per_group, sk) attn_weights = attn_weights.to(torch.float32) if (attention_mask := kwargs[AttentionKwargs.attention_mask]) is not None: @@ -212,51 +210,33 @@ def _attn_backup( attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( - attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk).to(value.dtype), value + attn_weights.view(self._local_head_groups, sq * self._local_heads_per_group, sk).to(value.dtype), value + ) + # head_group, sq * head_per_group, head_size -> sq, head_per_group * head_group, head_size + return ( + attn_output.view(self._local_head_groups, sq, self._local_heads_per_group, self._config.head_size) + .transpose(0, 1) + .flatten(1, 2) ) - - if self._local_head_groups == 1: - return attn_output.view(b, sq, -1) - else: - return ( - attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.head_size) - .transpose(1, 2) - .flatten(2) - ) def _attn_flash( self, query: 
torch.Tensor, key: torch.Tensor, value: torch.Tensor, kwargs: dict[str, typing.Any]
     ) -> torch.Tensor:
         assert _flash_available
         window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0)
-        if self._config.cross_document_attention:
-            return _flash_attn_func(
-                query,
-                key,
-                value,
-                window_size=window_size,
-                dropout_p=self._config.dropout if self.training else 0.0,
-                causal=self._config.causal,
-                softmax_scale=self._softmax_scale,
-            ).flatten(-2)
-        else:
-            return (
-                _flash_attn_varlen_func(
-                    query.view(-1, query.size(-2), query.size(-1)),
-                    key.view(-1, key.size(-2), key.size(-1)),
-                    value.view(-1, value.size(-2), value.size(-1)),
-                    kwargs[AttentionKwargs.cu_seqlens_q],
-                    kwargs[AttentionKwargs.cu_seqlens_k],
-                    kwargs[AttentionKwargs.max_seqlen_q],
-                    kwargs[AttentionKwargs.max_seqlen_k],
-                    dropout_p=self._config.dropout if self.training else 0.0,
-                    window_size=window_size,
-                    causal=self._config.causal,
-                    softmax_scale=self._softmax_scale,
-                )
-                .view(query.size())
-                .flatten(-2)
-            )
+        return _flash_attn_varlen_func(
+            query,
+            key,
+            value,
+            kwargs[AttentionKwargs.cu_seqlens_q],
+            kwargs[AttentionKwargs.cu_seqlens_k],
+            kwargs[AttentionKwargs.max_seqlen_q],
+            kwargs[AttentionKwargs.max_seqlen_k],
+            dropout_p=self._config.dropout if self.training else 0.0,
+            window_size=window_size,
+            causal=self._config.causal,
+            softmax_scale=self._softmax_scale,
+        )
 
     def _query_key_value_forward(
         self, input_: torch.Tensor
@@ -320,17 +300,10 @@ def _forward(
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         query, key_value = self._query_key_value(input_)
 
-        # Separate the batch and sequence dimensions
-        token_dims = (kwargs[AttentionKwargs.batch_dim], kwargs[AttentionKwargs.sequence_q_dim])
-        token_shape = tuple(dim.size for dim in token_dims)
-        query = query.unflatten(0, token_shape)
-        key_value = key_value.unflatten(0, token_shape)
-
-        # TODO: Move the rest to function.
-
+        # TODO: These get unnecessarily big with lots of small documents. 
if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: # Clear the lists so tensors can be de-allocated - key_value = torch.cat((past_key_values.pop(0), key_value), dim=1) + key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences @@ -342,12 +315,14 @@ def _forward( key_value = key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) - query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) - key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) - value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size) + query = query.unflatten(-1, (self._local_heads, self._config.head_size)) + key = key.unflatten(-1, (self._local_head_groups, self._config.head_size)) + value = value.unflatten(-1, (self._local_head_groups, self._config.head_size)) - self._debug(query, "query_rotary_input", token_dims + self._query_dims, kwargs) - self._debug(key, "key_rotary_input", token_dims + self._kv_dims, kwargs) + self._debug( + query, "query_rotary_input", (token_dim := kwargs[AttentionKwargs.token_dim], *self._query_dims), kwargs + ) + self._debug(key, "key_rotary_input", (token_dim, *self._kv_dims), kwargs) query, key = self._rotary(query, key, kwargs) with set_generator(self._distributed.tp_generator): @@ -359,28 +334,36 @@ def _forward( else: raise NotImplementedError(self._implementation) - self._debug(query, "query", token_dims + self._query_dims, kwargs) - self._debug(key, "key", token_dims + self._kv_dims, kwargs) - self._debug(value, "value", token_dims + self._kv_dims, kwargs) - self._debug(input_, "context", token_dims + (self._dense_dim,), kwargs) + self._debug(query, "query", (token_dim, *self._query_dims), kwargs) + self._debug(key, "key", (token_dim, *self._kv_dims), kwargs) + self._debug(value, "value", (token_dim, *self._kv_dims), kwargs) + self._debug(input_, "context", (token_dim, self._dense_dim), kwargs) - out, bias = self.dense(input_.flatten(0, 1)) - self._debug(out, None, token_dims + (self._hidden_dim,), kwargs) + out, bias = self.dense(input_.flatten(1)) + self._debug( + out, + None, + ( + token_dim, + self._hidden_dim, + ), + kwargs, + ) return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - batch_dim: TensorDim = kwargs[AttentionKwargs.batch_dim] + # TODO: ====== Account for varlen ======= sequence_q_dim: TensorDim = kwargs[AttentionKwargs.sequence_q_dim] sequence_k_dim: TensorDim = kwargs[AttentionKwargs.sequence_k_dim] if config.global_: - batch_size, sequence_q = batch_dim.global_size, sequence_q_dim.global_size + sequence_q = sequence_q_dim.global_size # In case of sequence-data-parallel, we need to undo the shift in k-sequence-length. sequence_k = sequence_k_dim.global_size - sequence_q_dim.size * ( sequence_q_dim.parallel_dim.size - sequence_q_dim.parallel_dim.rank - 1 ) else: - batch_size, sequence_q = batch_dim.size, sequence_q_dim.size + sequence_q = sequence_q_dim.size sequence_k = sequence_k_dim.size # 2 for multiply and accumulate, 2 operations (Q * K, attn * V), double for backward + Q * K recomputation. 
@@ -422,12 +405,17 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) ) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + out = {} + if self._implementation == AttentionImplementation.flash: + out["return_cumulative_sequence_lengths"] = True + out["return_max_sequence_lengths"] = True + return out + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(kwargs) if self._implementation == AttentionImplementation.backup: self._preprocess_for_backup_attention(kwargs) - elif self._implementation == AttentionImplementation.flash: - self._preprocess_for_flash_attention(kwargs) def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None: device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device @@ -453,20 +441,15 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non ] else: attention_mask = None - if not self._config.cross_document_attention: - seq_ids = torch.stack( - [ - torch.cat([torch.full((x,), i, device=device) for i, x in enumerate(sample_lens)]) - for sample_lens in kwargs[AttentionKwargs.sequence_lengths] - ] - ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None])[ - :, None, sequence_k - sequence_q : sequence_k, None, :sequence_k - ] - if attention_mask is None: - attention_mask = document_mask - else: - attention_mask = attention_mask & document_mask + + preprocess_for_varlen(kwargs, device, return_seq_idx=True) + document_mask = (kwargs[AttentionKwargs.seq_idx][:, None, :] == kwargs[AttentionKwargs.seq_idx][:, :, None])[ + :, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + ] + if attention_mask is None: + attention_mask = document_mask + else: + attention_mask = attention_mask & document_mask kwargs[AttentionKwargs.attention_mask] = attention_mask @@ -479,12 +462,3 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non device=device, ) kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value - - def _preprocess_for_flash_attention(self, kwargs: dict[str, typing.Any]) -> None: - if not self._config.cross_document_attention: - preprocess_for_varlen( - kwargs, - kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device, - return_cu_seqlens=True, - return_max_seqlen=True, - ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 40baf2009..a2221eff7 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -120,12 +120,6 @@ class AttentionConfig(MixerConfig): desc="The implementation to use for the attention layer. 
Default: `flash` if supported, otherwise `backup`.", hint=FieldHint.feature, ) - cross_document_attention: bool = Field( - default=True, - desc="Allow for cross-document attention.", - doc="Disable to prevent attention between tokens belonging to different documents.", - hint=FieldHint.feature, - ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 4f8595250..bf35765d0 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -35,6 +35,7 @@ class BlockKwargs: sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" token_dim = "token_dim" + num_tokens = "num_tokens" hidden_token_dim = "hidden_token_dim" # TODO: These are confusing sequence_length = "sequence_length" diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index 2e7425343..eacc04611 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -7,10 +7,11 @@ from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.layers.block.block import BlockBase from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.utils import safe_merge_dicts class FixedBlockSequence[ConfigType: FixedBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): @@ -61,6 +62,9 @@ def _layers_with_namespace(self) -> list[Layer]: def get_layers(self) -> list["Layer"]: return self._layers_with_namespace + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return self._layers_with_namespace[0].get_preprocessing_config(phase) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._layers_with_namespace[0].preprocess(kwargs) @@ -121,6 +125,12 @@ def _layers_with_namespace(self) -> list[Layer]: def get_layers(self) -> list[Layer]: return self._layers_with_namespace + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return safe_merge_dicts( + self._layers_with_namespace[index].get_preprocessing_config(phase) + for _, index in self._config.preprocessing_layers.items() + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: for _, index in self._config.preprocessing_layers.items(): self._layers_with_namespace[index].preprocess(kwargs) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index dd19c1086..4a2e066c3 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ReduceOp, all_reduce, set_generator from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.layers.block.block import Block @@ -15,7 +15,7 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig 
from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert +from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -206,6 +206,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) ) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return safe_merge_dicts(self.mixer.get_preprocessing_config(phase), self.mlp.get_preprocessing_config(phase)) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.mixer.preprocess(kwargs) self.mlp.preprocess(kwargs) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0e54e7583..4a2422cd9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -56,13 +56,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - cross_document_position_embeddings: bool = Field( - default=True, - desc="Allow for cross-document position embeddings.", - doc="Disable to reset position ids at the beginning of each document.", - hint=FieldHint.feature, - ) - dropout: float = Field( default=0.0, desc="Dropout applied to the embedding layer.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index c6df8f62b..ed685b416 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,8 +7,7 @@ from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.attention.preprocessing import preprocess_for_varlen +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs @@ -179,15 +178,8 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c # TODO: Add marginal compute? (embeddings) return 0 - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - if not self._config.position_embeddings.enabled: - return - # TODO: Move to data preprocessing. 
- if self._config.cross_document_position_embeddings: - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - kwargs[LanguageModelKwargs.position_ids] = torch.arange( - sequence_k - sequence_q, sequence_k, device=self._distributed.device, dtype=torch.int64 - ).repeat(kwargs[LanguageModelKwargs.batch_dim].size) - else: - preprocess_for_varlen(kwargs, self._distributed.device, return_position_ids=True) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + out = {"vocab_size": self.embeddings.vocab_size} + if self._config.position_embeddings.enabled: + out["return_position_index"] = True + return out diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 32e2ccbf9..bdd261d28 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -4,11 +4,12 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.layers.block.block import BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.utils import safe_merge_dicts logger = logging.getLogger(__name__) @@ -65,6 +66,14 @@ def get_layers(self) -> list[Layer]: layers += self.multi_token_prediction.get_layers() return layers + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return safe_merge_dicts( + self.embeddings.get_preprocessing_config(phase), + self.decoder.get_preprocessing_config(phase), + self.head.get_preprocessing_config(phase), + {} if self.multi_token_prediction is None else self.multi_token_prediction.get_preprocessing_config(phase), + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(kwargs) diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index d7665cf00..a828cacc1 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -6,7 +6,7 @@ from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.layers.block.block import BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -87,6 +87,10 @@ def get_layers(self) -> list[Layer]: def get_output_weights(self) -> list[torch.Tensor]: return sum((head.get_output_weights() for head in self.heads), []) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + if self._enabled: + self._layers_with_namespace[0].get_preprocessing_config(phase) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: if self._enabled: self._layers_with_namespace[0].preprocess(kwargs) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 5e721d424..5f6374820 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -8,10 +8,9 @@ from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs -from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import GatedDeltaNetConfig @@ -370,13 +369,8 @@ def _forward( return output - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - preprocess_for_varlen( - kwargs, - kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, - return_cu_seqlens=True, - return_seq_idx=True, - ) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return {"return_cumulative_sequence_lengths": True, "return_document_index": True} def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 07ca3a997..1fe56470e 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -7,10 +7,9 @@ from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from 
fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs -from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig @@ -290,10 +289,5 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - preprocess_for_varlen( - kwargs, - kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, - return_cu_seqlens=True, - return_seq_idx=True, - ) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return {"return_cumulative_sequence_lengths": True, "return_document_index": True} diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index fd6255e6c..275a1fae9 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -7,10 +7,9 @@ from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs -from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -167,7 +166,7 @@ def _forward( assert _mamba_available sequence_length = kwargs[BlockKwargs.sequence_q_dim].size - token_shape = (kwargs[BlockKwargs.batch_dim].size, kwargs[BlockKwargs.sequence_q_dim].size) + token_shape = (1, kwargs[BlockKwargs.sequence_q_dim].size) # inner_projection : (local_tokens, hidden) -> (batch, sequence, local_inner_projection) inner_projection = self.in_proj(input_).unflatten(0, token_shape) dt = self.dt_proj(self.dt_in_proj(input_)).unflatten(0, token_shape) @@ -250,14 +249,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c # TODO: Implement. 
raise NotImplementedError() - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: if not self._config.cross_document_attention: assert ( _mamba_varlen_available ), f"Varlen mamba requires custom mamba installation from `https://github.com/jxiw/varlen_mamba`" - preprocess_for_varlen( - kwargs, - kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, - return_seq_idx=True, - return_position_ids=True, - ) + return {"return_position_index": True, "return_document_index": True} diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py index 1bd499f97..a014f6f5a 100644 --- a/fast_llm/layers/vision/vision_encoder.py +++ b/fast_llm/layers/vision/vision_encoder.py @@ -5,11 +5,12 @@ from fast_llm.engine.base_model.base_model import Layer, LayerBaseWithNamespace from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.layers.block.block import BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.vision.config import VisionEncoderConfig, VisionMultiModalModelConfig +from fast_llm.utils import safe_merge_dicts logger = logging.getLogger(__name__) @@ -53,6 +54,14 @@ def __init__( def get_layers(self) -> list["Layer"]: return self.embeddings.get_layers() + self.encoder.get_layers() + self.adapter.get_layers() + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + # Needed because the base class uses `get_layers` which may bypass the decoder. TODO: Avoidable? + return safe_merge_dicts( + self.embeddings.get_preprocessing_config(phase), + self.encoder.get_preprocessing_config(phase), + self.adapter.get_preprocessing_config(phase), + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder. TODO: Avoidable? 
self.embeddings.preprocess(kwargs) @@ -98,6 +107,12 @@ def __init__( def get_layers(self) -> list[Layer]: return self._vision_encoder_with_namespace.get_layers() + super().get_layers() + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return safe_merge_dicts( + self._vision_encoder_with_namespace.get_preprocessing_config(phase), + super().get_preprocessing_config(phase), + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._vision_encoder_with_namespace.preprocess(kwargs) super().preprocess(kwargs) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 698f624ed..e32b78ff9 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,19 +5,18 @@ import torch -from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.batch.language_model import LanguageModelBatchNew from fast_llm.engine.base_model.base_model import BaseModel -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -41,116 +40,9 @@ def __init__( Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa - def preprocess_meta( - self, batch_meta: GPTBatchConfig | LanguageModelBatch, phase: PhaseType - ) -> list[tuple[TensorMeta, dict]]: - # TODO Remove (Move batch splitting elsewhere) - # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence - - if isinstance(batch_meta, GPTBatchConfig): - micro_batch_size = batch_meta.micro_batch_size - sequence_length = batch_meta.sequence_length - micro_sequence_length = batch_meta.micro_sequence_length - truncate_documents = batch_meta.truncate_documents - else: - micro_batch_size, sequence_length = batch_meta.tokens.tokens.shape - if phase != PhaseType.inference: - sequence_length -= self._config.head.prediction_heads - micro_sequence_length = sequence_length - truncate_documents = True - - batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) - - if micro_sequence_length is None: - micro_sequence_length = sequence_length - else: - Assert.multiple(sequence_length, micro_sequence_length) - - # TODO: Calculate hidden dims elsewhere? 
- sequence_q_dim = TensorDim( - BlockDimNames.sequence_q, - micro_sequence_length, - self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), - ) - token_dim = TensorDim( - "token", - batch_dim.global_size * sequence_q_dim.global_size, - self._distributed_config.get_distributed_dim(DistributedDimNames.data), - ) - # The token dimension as appears in hidden states, i.e. with possible sequence-tensor-parallel split. - hidden_token_dim = ( - ( - "token_tp", - token_dim.global_size, - self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), - ) - if self._distributed_config.sequence_tensor_parallel - else token_dim - ) - - common_kwargs = { - LanguageModelKwargs.phase: phase, - AttentionKwargs.sequence_length: sequence_length, - AttentionKwargs.batch_dim: batch_dim, - AttentionKwargs.sequence_q_dim: sequence_q_dim, - AttentionKwargs.token_dim: token_dim, - AttentionKwargs.hidden_token_dim: hidden_token_dim, - LanguageModelKwargs.mask_inputs: not truncate_documents, - } - - sequence_k_pasts = range( - sequence_q_dim.size * self._distributed_config.sequence_data_rank, - sequence_length, - micro_sequence_length, - ) - reference_preprocessed_metas = {} - for name, reference_model in self._reference_models.items(): - reference_preprocessed_metas[name] = reference_model.fast_llm_model.base_model.preprocess_meta( - batch_meta, PhaseType.inference - ) - Assert.eq(len(reference_preprocessed_metas[name]), len(sequence_k_pasts)) - - preprocessed_meta = [] - for i, sequence_k_past in enumerate(sequence_k_pasts): - sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(BlockDimNames.sequence_k, sequence_k) - - tokens = TensorMeta.from_dims( - (token_dim,), tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 - ) - - kwargs = { - **common_kwargs, - AttentionKwargs.sequence_k_dim: sequence_k_dim, - } - if phase != PhaseType.inference: - kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( - (token_dim,), tensor_name="labels", dtype=torch.int64 - ) - reference_kwargs = {} - for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): - reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] - for key in ( - AttentionKwargs.sequence_length, - AttentionKwargs.batch_dim, - AttentionKwargs.sequence_q_dim, - AttentionKwargs.sequence_k_dim, - AttentionKwargs.token_dim, - AttentionKwargs.hidden_token_dim, - ): - Assert.eq(reference_kwargs_[key], kwargs[key]) - reference_kwargs[name] = reference_kwargs_ - kwargs["reference_models"] = reference_kwargs - - preprocessed_meta.append((tokens, kwargs)) - - return preprocessed_meta - def preprocess_batch( self, - batch: LanguageModelBatch, - preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + batches: list[LanguageModelBatchNew], *, phase: PhaseType, iteration: int, @@ -160,79 +52,53 @@ def preprocess_batch( # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup - batch.to_device_(self._distributed.device) - - if preprocessed_meta is None: - preprocessed_meta = self.preprocess_meta(batch, phase) - reference_preprocessed_batches = {} for name, reference_model in self._reference_models.items(): - reference_preprocessed_meta = [ - (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta - ] reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch( - batch, - reference_preprocessed_meta, + batches, 
phase=PhaseType.inference, iteration=iteration, ) preprocessed = [] presents = None - for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - tokens_end = kwargs_meta[AttentionKwargs.sequence_k_dim].size - tokens_begin = tokens_end - kwargs_meta[AttentionKwargs.sequence_q_dim].size - cropped_tokens = batch.tokens.crop(tokens_begin, tokens_end) - - # TODO: Add pasts/presents to meta input? - # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. + for micro_sequence_index, batch in enumerate(batches): pasts = presents - presents = None if i == len(preprocessed_meta) - 1 else [] - + presents = None if micro_sequence_index == len(batches) - 1 else [] + batch.to_device_(self._distributed.device) kwargs: dict[str, typing.Any] = { - **kwargs_meta, + LanguageModelKwargs.phase: phase, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, - BlockKwargs.iteration: iteration, - AttentionKwargs.sequence_lengths: cropped_tokens.lengths, - AttentionKwargs.device: self._distributed.device, - BlockKwargs.output_hidden_states: [], - BlockKwargs.hidden_states: {}, + LanguageModelKwargs.iteration: iteration, + LanguageModelKwargs.device: self._distributed.device, + LanguageModelKwargs.output_hidden_states: [], + LanguageModelKwargs.hidden_states: {}, + LanguageModelKwargs.token_dim: batch.token_dim, + LanguageModelKwargs.hidden_token_dim: batch.hidden_token_dim, + LanguageModelKwargs.sequence_k_dim: batch.sequence_k_dim, + LanguageModelKwargs.num_tokens: batch.num_tokens, + LanguageModelKwargs.sequence_length: batch.sequence_length, + LanguageModelKwargs.sequence_lengths: batch.document_lengths, + LanguageModelKwargs.labels: batch.labels, + LanguageModelKwargs.loss_mask: batch.prediction_masks, + AttentionKwargs.cu_seqlens_q: batch.cumulative_lengths_q, + AttentionKwargs.cu_seqlens_k: batch.cumulative_lengths_k, + AttentionKwargs.max_seqlen_q: batch.max_length_q, + AttentionKwargs.max_seqlen_k: batch.max_length_k, + LanguageModelKwargs.seq_idx: batch.document_index, + LanguageModelKwargs.position_ids: batch.position_index, + LanguageModelKwargs.chosen_spans: batch.chosen_spans, + LanguageModelKwargs.rejected_spans: batch.rejected_spans, } if extra_kwargs is not None: Assert.empty(kwargs.keys() & extra_kwargs.keys()) kwargs.update(extra_kwargs) - - # TODO: Simplify, check more carefully if needed. 
- if self._decoder_reference_models: - # Create activation mask for activation distillation - # This mask should: - # - Be 0 on padding tokens (added at the end when documents aren't truncated) - # - Be 1 on image placeholder tokens (token value -100 but not padding) - # - Be 1 on all other valid tokens (ignores loss-masking-spans) - # - # Note: Padding is added as a separate document with all tokens = -100 - # We detect padding by checking if all tokens in a document segment are -100 - activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool) - - for sample_index, sample_lengths in enumerate(cropped_tokens.lengths): - # Iterate through documents in this sample - pos = 0 - for doc_length in sample_lengths: - # Check if this document is padding (all tokens are -100) - doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length] - is_padding_doc = torch.all(doc_tokens == -100).item() - - if is_padding_doc: - # This is a padding document, mask it out - activation_mask[sample_index, pos : pos + doc_length] = False - - pos += doc_length - - kwargs[BlockKwargs.activation_mask] = activation_mask.flatten() + if phase == PhaseType.inference: + kwargs[BlockKwargs.output_hidden_states].append(re.compile(r"head\..*logits.*$")) for name, reference_model in self._reference_models.items(): - reference_tokens, reference_kwargs = reference_preprocessed_batches[name][i] + reference_tokens, reference_kwargs = reference_preprocessed_batches[name][micro_sequence_index] if name in self._decoder_reference_models: # TODO: Get the actual names reference_kwargs[BlockKwargs.output_hidden_states].append( @@ -245,40 +111,8 @@ def preprocess_batch( layer_name: tensor for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() } - - if phase == PhaseType.inference: - kwargs[BlockKwargs.output_hidden_states].append(re.compile(r"head\..*logits.*$")) - else: - labels_begin = tokens_begin + 1 - labels_end = tokens_end + self._config.head.prediction_heads - labels = batch.tokens.crop(labels_begin, labels_end).tokens - - if batch.loss_masking_spans is not None: - loss_masking_spans = batch.loss_masking_spans.crop(labels_begin, labels_end) - loss_mask = torch.ones_like(labels, dtype=torch.bool) - for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): - for begin, end in loss_masking_spans: - loss_mask[sample_index, begin:end] = False - labels = torch.where(loss_mask, labels, -100) - - labels = labels.flatten(0, 1) - kwargs[LanguageModelKwargs.labels] = labels - - if self._config.head.get_reference_models(): # loss masks only used for distillation currently - # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders - kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 - - if batch.chosen_spans is not None: - kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges - - if batch.rejected_spans is not None: - kwargs[LanguageModelKwargs.rejected_spans] = batch.rejected_spans.crop( - labels_begin, labels_end - ).ranges - - tokens = cropped_tokens.tokens.flatten(0, 1) self.preprocess(kwargs) - preprocessed.append((tokens, kwargs)) + preprocessed.append((batch.tokens, kwargs)) return preprocessed diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index ef4956176..df7f78643 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -1,9 +1,10 @@ import logging import typing +from fast_llm.batch.config import 
LanguageModelBatchPreprocessingConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -31,13 +32,12 @@ def _get_sampling_parameters( return parameters if _return_dict else SamplingParameters(**parameters) def _get_preprocessing_config( - self, *, _return_dict: bool = False - ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: - + self, phase: PhaseType, *, _return_dict: bool = False + ) -> LanguageModelBatchPreprocessingConfig | dict[str, typing.Any]: out = { - "type": "language_model", - "vocab_size": self._config.model.base_model.embeddings.vocab_size, + "phase": phase, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "use_preference_spans": self._config.batch.use_preference_spans, + **self._multi_stage.base_model.get_preprocessing_config(phase), } - return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) + return out if _return_dict else LanguageModelBatchPreprocessingConfig.from_dict(out) diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 307a67c63..d7bff8477 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -53,7 +53,6 @@ def import_config(cls, config: dict) -> dict: "head_size": config["head_size"], "add_linear_biases": config["add_linear_biases"], "causal": config["causal"], - "cross_document_attention": config["cross_document_attention"], } @classmethod @@ -74,7 +73,7 @@ def export_config(cls, config: AttentionConfig) -> dict: "head_size": config.head_size, "add_linear_biases": config.add_linear_biases, "causal": config.causal, - "cross_document_attention": config.cross_document_attention, + "cross_document_attention": False, "rotary": { "type": rotary_type, "theta": config.rotary.theta, diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index a75d732b8..8af22e065 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -56,7 +56,6 @@ def import_config(cls, config: dict) -> dict: out = super().import_config(config) out["rotary"]["type"] = "default_2d" out["causal"] = False - out["cross_document_attention"] = False return out @classmethod @@ -66,7 +65,6 @@ def export_config(cls, config: AttentionConfig) -> dict: Assert.is_(type(config.rotary), Rotary2DConfig) assert not config.add_linear_biases assert not config.causal - assert not config.cross_document_attention Assert.eq(config.head_groups, config.heads) return { "num_attention_heads": config.heads, diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index e90bd4d89..87d8f3310 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -5,7 +5,7 @@ from fast_llm.core.distributed import all_gather_scalar from fast_llm.data.sample.language_model import LanguageModelBatch -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import 
InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs @@ -133,7 +133,6 @@ def preprocess_meta( ) kwargs[self._vision_encoder_namespace] = { VisionKwargs.sequence_length: kwargs[VisionKwargs.sequence_length], - VisionKwargs.batch_dim: scalar_dim, VisionKwargs.sequence_q_dim: token_dim, VisionKwargs.sequence_k_dim: token_dim, VisionKwargs.token_dim: token_dim, diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 924c2cc7f..fa7207926 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -8,10 +8,9 @@ from fast_llm.utils import Assert -@pytest.mark.parametrize("cross_document_attention", (True, False)) @pytest.mark.parametrize(("causal", "window_size"), ((True, None), (True, 50), (False, None))) @pytest.mark.skipif(not _flash_available, reason="Flash attention not available") -def test_attention_implementations(cross_document_attention: bool, causal: bool, window_size: int | None): +def test_attention_implementations(causal: bool, window_size: int | None): """ Check that the flash and backup attention implementation give the same result. """ @@ -21,7 +20,6 @@ def test_attention_implementations(cross_document_attention: bool, causal: bool, heads=4, head_groups=2, window_size=window_size, - cross_document_attention=cross_document_attention, causal=causal, ).get_layer( DistributedConfig(compute_dtype="bfloat16"), diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index a8ae85c12..c14232b4f 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,7 +6,6 @@ import torch from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead @@ -88,8 +87,6 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: ) label_shape = (BATCH_SIZE * (SEQUENCE_LENGTH + self.prediction_heads - 1),) kwargs: dict[str, typing.Any] = { - AttentionKwargs.batch_dim: TensorDim("batch", BATCH_SIZE), - AttentionKwargs.sequence_q_dim: TensorDim("sequence_q", SEQUENCE_LENGTH), AttentionKwargs.grad_output: 1.0, } if self.loss_masking: diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index d31cffa50..d262e414c 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -20,14 +20,13 @@ @pytest.mark.parametrize( "config", [ - AttentionConfig(heads=4, head_groups=2, head_size=16, cross_document_attention=False), + AttentionConfig(heads=4, head_groups=2, head_size=16), pytest.param( MambaConfig( d_inner=128, d_xb=64, state_size=16, dt_rank=8, - cross_document_attention=False, ), marks=pytest.mark.skip("Mamba varlen kernel not available"), ), @@ -73,7 +72,6 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): **kwargs, BlockKwargs.sequence_lengths: sequence_lengths, BlockKwargs.sequence_length: seq_len, - BlockKwargs.batch_dim: TensorDim("", batch_size), BlockKwargs.sequence_q_dim: TensorDim("", seq_len), BlockKwargs.sequence_k_dim: TensorDim("", seq_len), } diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 40dbb7d29..b5b74fb9e 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -240,7 +240,6 @@ def update_and_add_testing_config( "heads": 8, "head_groups": 8, "head_size": 32, - # 
"cross_document_attention":False, }, "mlp": { "layer_1": {"weight": init_1}, @@ -711,7 +710,6 @@ def update_and_add_testing_config( ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "rotary", "type"): "default_2d", ("model", "base_model", "vision_encoder", "encoder", "num_blocks"): 1, ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "causal"): False, - ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "cross_document_attention"): False, # Pixtal doesn't support GQA ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "head_groups"): 8, }, @@ -932,7 +930,6 @@ def update_and_add_testing_config( ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "rotary", "type"): "default_2d", ("model", "base_model", "vision_encoder", "encoder", "num_blocks"): 1, ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "causal"): False, - ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "cross_document_attention"): False, # Pixtral doesn't support GQA ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "head_groups"): 8, }, From 295c25bfa88b7c856a33815071e89e8c1799b685 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 18 Feb 2026 18:32:49 -0500 Subject: [PATCH 2/4] stuff --- fast_llm/data/auto.py | 10 +- fast_llm/{ => data}/batch/__init__.py | 0 fast_llm/{ => data}/batch/config.py | 42 +- fast_llm/{ => data}/batch/language_model.py | 66 ++- fast_llm/data/data/abstract.py | 5 +- fast_llm/data/data/gpt/config.py | 4 +- fast_llm/data/data/gpt/data.py | 34 +- fast_llm/data/dataset/abstract.py | 14 +- fast_llm/data/dataset/blended.py | 9 +- fast_llm/data/dataset/config.py | 74 +-- fast_llm/data/dataset/gpt/config.py | 32 +- fast_llm/data/dataset/gpt/fim.py | 34 +- fast_llm/data/dataset/gpt/legacy_memmap.py | 24 +- fast_llm/data/dataset/gpt/random.py | 33 +- fast_llm/data/dataset/indexed.py | 18 +- .../{sample => dataset/memmap}/__init__.py | 0 fast_llm/data/dataset/memmap/abstract.py | 119 ++++ fast_llm/data/dataset/memmap/config.py | 462 ++++++++++++++++ .../data/dataset/memmap/language_model.py | 237 ++++++++ fast_llm/data/dataset/{ => memmap}/memmap.py | 25 +- fast_llm/data/dataset/memmap/patch.py | 141 +++++ fast_llm/data/dataset/memmap/range.py | 73 +++ fast_llm/data/dataset/memmap/token.py | 95 ++++ fast_llm/data/dataset/monitor.py | 8 +- fast_llm/data/dataset/sampled.py | 48 +- fast_llm/data/document/__init__.py | 0 fast_llm/data/document/abstract.py | 23 + fast_llm/data/document/language_model.py | 90 +++ fast_llm/data/document/patch.py | 66 +++ fast_llm/data/document/range.py | 37 ++ fast_llm/data/document/token.py | 105 ++++ .../preparator/dataset_discovery/prepare.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 45 +- fast_llm/data/sample/abstract.py | 270 --------- fast_llm/data/sample/language_model.py | 511 ------------------ fast_llm/data/sample/patch.py | 359 ------------ fast_llm/data/sample/range.py | 173 ------ fast_llm/data/sample/token.py | 265 --------- fast_llm/models/gpt/huggingface.py | 2 - fast_llm/models/gpt/model.py | 46 +- fast_llm/models/gpt/trainer.py | 2 +- fast_llm/models/multimodal/huggingface.py | 1 - fast_llm/models/multimodal/model.py | 1 - tests/data/common.py | 31 +- tests/data/test_blending.py | 6 +- tests/data/test_concatenate.py | 4 +- tests/data/test_image_patch.py | 13 +- tests/data/test_loss_masking_spans.py | 11 +- tests/data/test_preference_spans.py | 6 +- tests/data/test_preparator.py | 5 +- 
tests/data/test_sampling.py | 14 +- tests/data/test_slice.py | 4 +- tests/models/test_match_megatron.py | 39 +- tests/test_loss_mask.py | 3 - 54 files changed, 1830 insertions(+), 1911 deletions(-) rename fast_llm/{ => data}/batch/__init__.py (100%) rename fast_llm/{ => data}/batch/config.py (75%) rename fast_llm/{ => data}/batch/language_model.py (73%) rename fast_llm/data/{sample => dataset/memmap}/__init__.py (100%) create mode 100644 fast_llm/data/dataset/memmap/abstract.py create mode 100644 fast_llm/data/dataset/memmap/config.py create mode 100644 fast_llm/data/dataset/memmap/language_model.py rename fast_llm/data/dataset/{ => memmap}/memmap.py (85%) create mode 100644 fast_llm/data/dataset/memmap/patch.py create mode 100644 fast_llm/data/dataset/memmap/range.py create mode 100644 fast_llm/data/dataset/memmap/token.py create mode 100644 fast_llm/data/document/__init__.py create mode 100644 fast_llm/data/document/abstract.py create mode 100644 fast_llm/data/document/language_model.py create mode 100644 fast_llm/data/document/patch.py create mode 100644 fast_llm/data/document/range.py create mode 100644 fast_llm/data/document/token.py delete mode 100644 fast_llm/data/sample/abstract.py delete mode 100644 fast_llm/data/sample/language_model.py delete mode 100644 fast_llm/data/sample/patch.py delete mode 100644 fast_llm/data/sample/range.py delete mode 100644 fast_llm/data/sample/token.py diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index f400978bf..2e89695b3 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -6,9 +6,16 @@ BlendedDatasetConfig, ConcatenatedDatasetConfig, DatasetSliceConfig, - MemmapDatasetConfig, SampledDatasetUpdateConfig, ) +from fast_llm.data.dataset.memmap.config import ( # isort: skip + LanguageModelReaderConfig, + MemmapDatasetConfig, + NullReaderConfig, + PatchReaderConfig, + RangeReaderConfig, + TokenReaderConfig, +) from fast_llm.data.dataset.gpt.config import ( # isort: skip GPTDatasetFromFileConfig, GPTFimSampledDatasetConfig, @@ -16,4 +23,3 @@ ) from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig # isort: skip from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip -from fast_llm.data.sample.abstract import NullReaderConfig # isort: skip diff --git a/fast_llm/batch/__init__.py b/fast_llm/data/batch/__init__.py similarity index 100% rename from fast_llm/batch/__init__.py rename to fast_llm/data/batch/__init__.py diff --git a/fast_llm/batch/config.py b/fast_llm/data/batch/config.py similarity index 75% rename from fast_llm/batch/config.py rename to fast_llm/data/batch/config.py index f857d115b..a3d192bae 100644 --- a/fast_llm/batch/config.py +++ b/fast_llm/data/batch/config.py @@ -1,30 +1,35 @@ +import dataclasses import functools import logging import typing -from fast_llm.config import Field, FieldUpdate, config_class +from fast_llm.config import Field, config_class +from fast_llm.data.document.abstract import Document from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImagePatchConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig -from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.models.gpt.config 
import GPTBatchConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + import torch + logger = logging.getLogger(__name__) -@config_class(registry=True) +@config_class() class BatchPreprocessingConfig(PreprocessingConfig): - batch: BatchConfig = Field() + pass -@config_class(dynamic_type={PreprocessingConfig: "language_model"}) +@config_class() class LanguageModelBatchPreprocessingConfig(LanguageModelPreprocessingConfig): _abstract = False # TODO: Duplicate `use_loss_masking_spans`, `use_preference_spans` - batch: GPTBatchConfig = FieldUpdate() + batch: GPTBatchConfig = Field() phase: PhaseType = Field(default=PhaseType.inference) predicted_tokens: int = Field(default=1) return_cumulative_sequence_lengths: bool = Field(default=False) @@ -52,3 +57,28 @@ def check_compatibility(self, preprocessing: typing.Self) -> None: assert self.use_preference_spans, "The dataset is missing required preference spans" if preprocessing.use_image_patches and self.use_image_patches: self.image_patches.check_compatibility(preprocessing.image_patches) + + +@dataclasses.dataclass +class MicroBatch: + pass + + +@dataclasses.dataclass +class PreprocessedBatch: + micro_batches: list[MicroBatch] + + +@config_class(registry=True) +class BatchPreprocessingConfig(PreprocessingConfig): + batch: BatchConfig = Field() + + @classmethod + def from_documents( + cls, + config: BatchPreprocessingConfig, + distributed_config: DistributedConfig, + documents: list[Document], + device: "torch.device | None" = None, + ) -> typing.Self: + pass diff --git a/fast_llm/batch/language_model.py b/fast_llm/data/batch/language_model.py similarity index 73% rename from fast_llm/batch/language_model.py rename to fast_llm/data/batch/language_model.py index 7de5c07e3..b0f67fc1c 100644 --- a/fast_llm/batch/language_model.py +++ b/fast_llm/data/batch/language_model.py @@ -3,14 +3,14 @@ import torch -from fast_llm.batch.config import LanguageModelBatchPreprocessingConfig -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig, MicroBatch, PreprocessedBatch +from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @dataclasses.dataclass -class LanguageModelBatchNew: +class LanguageModelMicroBatch(MicroBatch): tokens: torch.Tensor token_dim: TensorDim hidden_token_dim: TensorDim @@ -27,8 +27,7 @@ class LanguageModelBatchNew: max_length_k: torch.Tensor | None = None document_index: torch.Tensor | None = None position_index: torch.Tensor | None = None - chosen_spans: list[tuple[int, int]] | None = None - rejected_spans: list[tuple[int, int]] | None = None + # TODO: ====== Preference spans? 
====== def to_device_(self, device: torch.device): self.tokens = self.tokens.to(device, non_blocking=True) @@ -45,24 +44,37 @@ def to_device_(self, device: torch.device): if self.position_index is not None: self.position_index = self.position_index.to(device, non_blocking=True) + +@dataclasses.dataclass +class LanguageModelPreprocessedBatch(PreprocessedBatch): + micro_batches: list[LanguageModelMicroBatch] + @classmethod def from_documents( cls, + documents: list[LanguageModelDocument], + *, config: LanguageModelBatchPreprocessingConfig, distributed_config: DistributedConfig, - documents: list[LanguageModelSample], device: torch.device | None = None, - ) -> list[typing.Self]: - num_tokens = sum(len(document) for document in documents) - padding = config.batch.sequence_length + config.predicted_tokens - num_tokens - sample = LanguageModelSample.from_documents(documents + [documents[0].get_padding(padding)]) - # sample.tokens.lengths - # lengths = [len(document) for document in documents] - # num_tokens = sum(lengths) + ) -> typing.Self: + batch = LanguageModelBatch.from_documents( + documents, pad_to_size=config.batch.sequence_length + config.predicted_tokens + ) + return cls.from_batch(batch, config=config, distributed_config=distributed_config, device=device) + @classmethod + def from_batch( + cls, + batch: LanguageModelBatch, + *, + config: LanguageModelBatchPreprocessingConfig, + distributed_config: DistributedConfig, + device: torch.device | None = None, + ) -> typing.Self: if device is None: - device = sample.tokens.tokens.device - sample.to_device_(device) + device = batch.tokens.tokens.device + batch.to_device_(device) token_dim = TensorDim( "token", @@ -88,19 +100,16 @@ def from_documents( ): sequence_k = sequence_k_past + token_dim.size sequence_k_dim = TensorDim("sequence_k", sequence_k) - cropped_sample = sample.crop(sequence_k_past, sequence_k) - - # document_lengths, cumulative_lengths_q, cumulative_lengths_k, first_document_index, remaining_tokens = crop_lengths( - # sample.tokens.lengths, sequence_k_past, sequence_k_past + token_dim.size) + cropped_sample = batch.crop(sequence_k_past, sequence_k) - micro_batch = LanguageModelBatchNew( - tokens=sample.tokens.tokens[sequence_k_past:sequence_k], + micro_batch = LanguageModelMicroBatch( + tokens=batch.tokens.tokens[sequence_k_past:sequence_k], token_dim=token_dim, hidden_token_dim=hidden_token_dim, sequence_k_dim=sequence_k_dim, - num_tokens=min(sequence_k, num_tokens) - sequence_k_past, + num_tokens=min(sequence_k, batch.num_tokens) - sequence_k_past, sequence_length=config.batch.sequence_length, - document_lengths=sample.tokens.lengths, + document_lengths=batch.tokens.lengths, ) if config.return_cumulative_sequence_lengths: micro_batch.cumulative_lengths_q, micro_batch.cumulative_lengths_k = ( @@ -112,19 +121,16 @@ def from_documents( micro_batch.document_index = cropped_sample.tokens.get_document_index() if config.return_position_index: micro_batch.position_index = cropped_sample.tokens.get_position_index() - if config.use_preference_spans: - micro_batch.chosen_spans = cropped_sample.chosen_spans.ranges - micro_batch.rejected_spans = cropped_sample.rejected_spans.ranges for prediction_distance in range(1, config.predicted_tokens + 1): label_begin = sequence_k_past + prediction_distance label_end = sequence_k + prediction_distance - label_tokens = sample.tokens.crop(label_begin, label_end) + label_tokens = batch.tokens.crop(label_begin, label_end) labels = label_tokens.tokens.clone() # Apply loss masking spans. 
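+                # Example: with sequence_k_past=0, sequence_k=4096 and prediction_distance=1,
+                # `labels` holds tokens[1:4097], i.e. the input tokens shifted one position ahead.
+                # Positions inside loss-masking spans are set to -100 below, the value used to mark masked labels.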
- if config.use_loss_masking_spans: - for span_begin, span_end in sample.loss_masking_spans.crop(label_begin, label_end).ranges: + if config.use_loss_masking_spans and batch.loss_masking_spans is not None: + for span_begin, span_end in batch.loss_masking_spans.crop(label_begin, label_end).ranges: labels[span_begin:span_end] = -100 # Mask cross-document predictions. @@ -141,4 +147,4 @@ def from_documents( micro_batch.prediction_masks.append(labels > 0) micro_batches.append(micro_batch) - return micro_batches + return LanguageModelPreprocessedBatch(micro_batches=micro_batches) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e01331be2..c5400b6c7 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -3,10 +3,10 @@ import typing from fast_llm.config import Configurable +from fast_llm.data.batch.config import PreprocessedBatch from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert @@ -54,5 +54,6 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[Batch]: + preprocess: bool = True, + ) -> typing.Iterator[PreprocessedBatch]: pass diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index ba5be883a..914699b74 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -8,7 +8,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.data.sample.language_model import LanguageModelSample + from fast_llm.data.document.language_model import LanguageModelDocument logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? 
- datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( + datasets: dict[str, SampledDatasetConfig["LanguageModelDocument"]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 3a1e99e6d..ff1fbd3bc 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,3 +1,4 @@ +import functools import logging import pathlib import typing @@ -7,6 +8,8 @@ import torch.utils.data from fast_llm.core.distributed import safe_barrier +from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig +from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch from fast_llm.data.data.abstract import Data from fast_llm.data.data.data_loader import SampledDatasetIterator from fast_llm.data.data.gpt.config import GPTDataConfig @@ -14,8 +17,7 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -32,7 +34,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, SamplingParameters] - _preprocessing: dict[str, LanguageModelPreprocessingConfig] + _preprocessing: dict[str, LanguageModelBatchPreprocessingConfig] _is_setup: bool = False def __init__( @@ -50,7 +52,7 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, SamplingParameters], - preprocessing: dict[str, LanguageModelPreprocessingConfig], + preprocessing: dict[str, LanguageModelBatchPreprocessingConfig], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: @@ -105,7 +107,8 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[LanguageModelBatch]: + preprocess: bool = True, + ) -> typing.Iterator[LanguageModelPreprocessedBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, @@ -130,7 +133,26 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=LanguageModelBatch.from_samples, + collate_fn=functools.partial(self._collate_fn, dataset_name=dataset_name, preprocess=preprocess), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) + + def _collate_fn( + self, + documents: list[list[LanguageModelDocument]], + dataset_name: str, + preprocess: bool = True, + ) -> LanguageModelPreprocessedBatch | LanguageModelBatch: + documents = [document for documents_ in documents for document in documents_] + config = self._preprocessing[dataset_name] + if preprocess: + return LanguageModelPreprocessedBatch.from_documents( + documents, + config=config, + distributed_config=self._distributed_config, + ) + else: + return LanguageModelBatch.from_documents( + documents, pad_to_size=config.batch.sequence_length + config.predicted_tokens + ) diff --git a/fast_llm/data/dataset/abstract.py 
b/fast_llm/data/dataset/abstract.py index 33942708b..ee34b64fc 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,13 +1,13 @@ import abc import typing -from fast_llm.data.sample.abstract import Sample +from fast_llm.data.document.abstract import Document if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData -class Dataset[SampleType: Sample](abc.ABC): +class Dataset[DocumentType: Document](abc.ABC): """ A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ @@ -21,21 +21,21 @@ def name(self) -> str: def __getstate__(self): state = super().__getstate__() - # Pickling sometimes fails with bound `SampleType`. + # Pickling sometimes fails with bound `DocumentType`. # This is not needed at runtime, so we just drop it. if "__orig_class__" in state: del state["__orig_class__"] return state -class SampledDataset[SampleType: Sample](Dataset[SampleType]): +class SampledDataset[DocumentType: Document](Dataset[DocumentType]): """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) """ @abc.abstractmethod - def __getitem__(self, index: int) -> SampleType: + def __getitem__(self, index: int) -> list[DocumentType]: pass @abc.abstractmethod @@ -43,8 +43,8 @@ def __len__(self) -> int: pass -class SamplableDataset[SampleType: Sample](Dataset[SampleType]): +class SamplableDataset[DocumentType: Document](Dataset[DocumentType]): @abc.abstractmethod - def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: + def sample(self, config: "SamplingData") -> SampledDataset[DocumentType]: pass diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 264eb373d..0cae40656 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -4,13 +4,13 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData -from fast_llm.data.sample.abstract import Sample +from fast_llm.data.document.abstract import Document from fast_llm.utils import Assert, normalize_probabilities logger = logging.getLogger(__name__) -class BlendedDataset[SampleType: Sample](SampledDataset[SampleType]): +class BlendedDataset[DocumentType: Document](SampledDataset[DocumentType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -21,7 +21,7 @@ class BlendedDataset[SampleType: Sample](SampledDataset[SampleType]): def __init__( self, name: str, - datasets: list[SampledDataset[SampleType]], + datasets: list[SampledDataset[DocumentType]], weights: list[float], sampling_config: SamplingData, ): @@ -35,7 +35,7 @@ def __init__( def __len__(self) -> int: return self._num_samples - def __getitem__(self, index: int) -> SampleType: + def __getitem__(self, index: int) -> list[DocumentType]: """ Blending is typically done in one of the following iterative way (ex. in Megatron datasets): ```python @@ -56,6 +56,7 @@ def __getitem__(self, index: int) -> SampleType: sampled = self._get_sampled(index) # Then get the present sample. dataset_index = self._get_next_dataset(index, sampled) + # TODO: ====== Can we mix documents from multiple datasets? 
====== return self._datasets[dataset_index][sampled[dataset_index].item()] def _get_sampled(self, num_samples: int) -> torch.Tensor: diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 2858d8d18..1e1fece26 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -9,8 +9,8 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.document.abstract import Document from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -69,6 +69,10 @@ class SamplingParameters: # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 + @functools.cached_property + def total_length(self) -> int: + return self.sequence_length + self.extra_tokens + @dataclasses.dataclass(kw_only=True) class SamplingData: @@ -99,37 +103,37 @@ def get_next_rank(self) -> int: @config_class() -class DatasetConfig[SampleType: Sample](Config): +class DatasetConfig[DocumentType: Document](Config): _abstract: typing.ClassVar[bool] = True @config_class(registry=True) -class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): +class SampledDatasetConfig[DocumentType: Document](DatasetConfig[DocumentType]): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[DocumentType]: raise NotImplementedError() @config_class() -class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: +class SamplableDatasetConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[DocumentType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[DocumentType]: return self.build(sampling.preprocessing).sample(sampling) @config_class() -class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): - def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]": +class IndexedDatasetConfig[DocumentType: Document](SamplableDatasetConfig[DocumentType]): + def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[DocumentType]": raise NotImplementedError() @config_class(dynamic_type={SampledDatasetConfig: "concatenated"}) -class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): +class ConcatenatedDatasetConfig[DocumentType: Document](SamplableDatasetConfig[DocumentType]): """ Concatenate multiple indexed datasets as if they were one. TODO: Make a post-sampling version? 
(staged training) @@ -141,7 +145,7 @@ class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[IndexedDatasetConfig[SampleType]] = Field( + datasets: list[IndexedDatasetConfig[DocumentType]] = Field( default_factory=list, desc="The datasets to concatenate.", hint=FieldHint.core, @@ -155,7 +159,7 @@ def build(self, preprocessing: PreprocessingConfig) -> "ConcatenatedDataset": @config_class(dynamic_type={SampledDatasetConfig: "slice"}) -class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): +class DatasetSliceConfig[DocumentType: Document](SamplableDatasetConfig[DocumentType]): """ Use a fraction of an indexed dataset, specified by the range (begin, end). Typically used to subsample a dataset, or to reserve part of the dataset for validation and/or testing. @@ -165,7 +169,7 @@ class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]) """ _abstract = False - dataset: IndexedDatasetConfig[SampleType] = Field( + dataset: IndexedDatasetConfig[DocumentType] = Field( default=None, desc="The dataset to split.", hint=FieldHint.core, @@ -186,7 +190,7 @@ def build(self, preprocessing: PreprocessingConfig) -> "DatasetSlice": dataset = self.dataset.build(preprocessing) size = len(dataset) - return DatasetSlice[SampleType]( + return DatasetSlice[DocumentType]( f"{dataset.name}_{self.begin}_{self.end}", dataset, round(self.begin * size), @@ -195,7 +199,7 @@ def build(self, preprocessing: PreprocessingConfig) -> "DatasetSlice": @config_class(dynamic_type={SampledDatasetConfig: "sampled"}) -class SampledDatasetUpdateConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): +class SampledDatasetUpdateConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): """ Wrap a dataset to explicitly sample from it and optionally update its configuration parameters. Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. @@ -206,24 +210,24 @@ class SampledDatasetUpdateConfig[SampleType: Sample](SampledDatasetConfig[Sample desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) - dataset: SampledDatasetConfig[SampleType] = Field( + dataset: SampledDatasetConfig[DocumentType] = Field( desc="The dataset to sample from.", hint=FieldHint.core, ) - def build_and_sample(self, data: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample(self, data: SamplingData) -> SampledDataset[DocumentType]: return self.dataset.build_and_sample(data.update_config(self.sampling)) @config_class(dynamic_type={SampledDatasetConfig: "blended"}) -class BlendedDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): +class BlendedDatasetConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): _abstract = False name: str = Field( default="blended", desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[SampledDatasetConfig[SampleType]] = Field( + datasets: list[SampledDatasetConfig[DocumentType]] = Field( default_factory=list, desc="The datasets to blend.", hint=FieldHint.core, @@ -243,7 +247,7 @@ def _validate(self) -> None: def build_and_sample( self, sampling: SamplingData, - ) -> SampledDataset[SampleType]: + ) -> SampledDataset[DocumentType]: from fast_llm.data.dataset.blended import BlendedDataset # Build and sample the datasets. 
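A minimal composition sketch of the registered dataset configs after this change, assuming the relocated memmap config keeps the `path` field of the config it replaces; the paths and weights below are hypothetical:
```python
# Hypothetical example: blend a full memmap dataset with the first 90% of another,
# using the dynamic_type keys registered in this module ("blended", "slice", "memmap").
from fast_llm.data.dataset.config import SampledDatasetConfig

config = SampledDatasetConfig.from_dict(
    {
        "type": "blended",
        "datasets": [
            {"type": "memmap", "path": "/data/tokens/web"},
            {"type": "slice", "dataset": {"type": "memmap", "path": "/data/tokens/code"}, "begin": 0.0, "end": 0.9},
        ],
        "weights": [0.7, 0.3],
    }
)
# config.build_and_sample(sampling) then yields a BlendedDataset that samples the two
# wrapped datasets with the given (normalized) probabilities.
```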
@@ -264,37 +268,9 @@ def build_and_sample( for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) ] # Blend the datasets. - return BlendedDataset[SampleType]( + return BlendedDataset[DocumentType]( self.name, sampled_datasets, self.weights, sampling, ) - - -@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class MemmapDatasetConfig[SampleType: Sample](IndexedDatasetConfig[SampleType]): - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", - hint=FieldHint.core, - ) - - def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]": - name = str(self.path).replace("/", "__") - if self.path.is_file(): - from fast_llm.data.dataset.memmap import MemmapDataset - - return MemmapDataset[SampleType](name, self.path, preprocessing) - elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): - logger.warning( - "Using the legacy memmap dataset format." - " This format is deprecated and will be removed in a future release." - " Please recreate the dataset in the new memmap format." - ) - from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset - - return LegacyMemmapDataset[SampleType](name, self.path, preprocessing) - else: - raise FileNotFoundError(self.path) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 5e978ac2b..b66bc5445 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -16,7 +16,7 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset - from fast_llm.data.sample.language_model import LanguageModelSample + from fast_llm.data.document.language_model import LanguageModelDocument @dataclasses.dataclass(kw_only=True) @@ -30,7 +30,7 @@ class GPTSamplingData(SamplingData): @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): +class GPTRandomDatasetConfig[DocumentType: LanguageModelDocument](SampledDatasetConfig[DocumentType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -38,14 +38,14 @@ class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConf hint=FieldHint.core, ) - def build_and_sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset[SampleType]": + def build_and_sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset[DocumentType]": from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset - return GPTRandomSampledDataset[SampleType](sampling, self.name) + return GPTRandomSampledDataset[DocumentType](sampling, self.name) @config_class(dynamic_type={SampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): +class GPTDatasetFromFileConfig[DocumentType: LanguageModelDocument](SamplableDatasetConfig[DocumentType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -53,22 +53,22 @@ class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDataset hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[DocumentType]: config = self._load_config() return 
config.build_and_sample(sampling) - def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[DocumentType]: config = self._load_config() assert isinstance(config, SamplableDatasetConfig) return config.build(preprocessing) - def _load_config(self) -> SampledDatasetConfig[SampleType]: + def _load_config(self) -> SampledDatasetConfig[DocumentType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] - return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(config)) + return SampledDatasetConfig[DocumentType].from_dict(self._convert_paths(config)) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. @@ -159,14 +159,14 @@ class FimConfig(Config): @config_class(dynamic_type={SampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType], FimConfig): +class GPTFimSampledDatasetConfig[DocumentType: LanguageModelDocument](SampledDatasetConfig[DocumentType], FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - dataset: SampledDatasetConfig[SampleType] = Field( + dataset: SampledDatasetConfig[DocumentType] = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -175,14 +175,14 @@ class GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDataset def build_and_sample( self, sampling: GPTSamplingData, - ) -> "GPTFimDataset[SampleType]": + ) -> "GPTFimDataset[DocumentType]": from fast_llm.data.dataset.gpt.fim import GPTFimDataset - return GPTFimDataset[SampleType](self, self.dataset.build_and_sample(sampling), sampling) + return GPTFimDataset[DocumentType](self, self.dataset.build_and_sample(sampling), sampling) @config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): +class GPTTestSlowDatasetConfig[DocumentType: LanguageModelDocument](SampledDatasetConfig[DocumentType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. 
""" @@ -195,8 +195,8 @@ class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetCo hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[DocumentType]: assert sampling.distributed.config.world_size > 1 if sampling.distributed.config.rank == 0: time.sleep(self.sleep) - return GPTRandomDatasetConfig[SampleType]().build_and_sample(sampling) + return GPTRandomDatasetConfig[DocumentType]().build_and_sample(sampling) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index b70fc8360..55ae7c1f3 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -3,13 +3,13 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.sample.token import TokenSample +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.document.token import TokenDocument from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): +class GPTFimDataset[DocumentType: LanguageModelDocument](SampledDataset[DocumentType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -18,7 +18,7 @@ class GPTFimDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]) def __init__( self, config: FimConfig, - dataset: SampledDataset[SampleType], + dataset: SampledDataset[DocumentType], sampling: GPTSamplingData, ): if sampling.preprocessing.use_loss_masking_spans: @@ -43,18 +43,28 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, index: int) -> SampleType: + def __getitem__(self, index: int) -> list[DocumentType]: # TODO: Use torch methods to avoid back and forth. 
- return LanguageModelSample( - TokenSample( - torch.from_numpy( - self._fim( - self._dataset[index].tokens.tokens.numpy(), - np.random.RandomState(seed=(self._seed + index) % MAX_SEED), + documents = self._dataset[index] + for document in documents: + assert document.loss_masking_spans is None + assert document.chosen_spans is None + assert document.rejected_spans is None + assert document.image_patches is None + + return [ + LanguageModelDocument( + tokens=TokenDocument( + tokens=torch.from_numpy( + self._fim( + document.tokens.tokens.numpy(), + np.random.RandomState(seed=(self._seed + index) % MAX_SEED), + ) ) ) ) - ) + for document in documents + ] @property def name(self) -> str: diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index d29e31596..0b47999b9 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -5,10 +5,10 @@ import torch from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.document.range import RangeDocument +from fast_llm.data.document.token import TokenDocument from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.sample.range import RangeSample -from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div @@ -25,7 +25,7 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -class LegacyMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): +class LegacyMemmapDataset[DocumentType: LanguageModelDocument](IndexedDataset[DocumentType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -153,7 +153,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get_document(self, index: int, begin: int = 0, end: int | None = None) -> SampleType: + def get_document(self, index: int, begin: int = 0, end: int | None = None) -> DocumentType: if end is None: end = self.get_document_size(index) sample_size = self._document_sizes[index].item() @@ -175,29 +175,29 @@ def get_document(self, index: int, begin: int = 0, end: int | None = None) -> Sa assert self._spans is not None if hasattr(self, "_spans"): # Convert to in range format (begin, end). - sample_spans = RangeSample( - [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size + sample_spans = RangeDocument( + ranges=[(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()] ).crop(begin, end) else: - sample_spans = RangeSample([], end - begin) + sample_spans = RangeDocument(ranges=[]) else: sample_spans = None if self._preprocessing.use_preference_spans: # Convert to in range format (begin, end). 
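+            # The legacy index stores inclusive (begin, last) pairs, hence the `+ 1` on the end offsets.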
- chosen_spans = RangeSample( + chosen_spans = RangeDocument( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], sample_size, ).crop(begin, end) - rejected_spans = RangeSample( + rejected_spans = RangeDocument( [(self._rejected_spans[index][0].item(), self._rejected_spans[index][1].item() + 1)], sample_size, ).crop(begin, end) else: chosen_spans = rejected_spans = None - return LanguageModelSample( - tokens=TokenSample(token_ids), + return LanguageModelDocument( + tokens=TokenDocument(token_ids), loss_masking_spans=sample_spans, chosen_spans=chosen_spans, rejected_spans=rejected_spans, diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 939b900e5..387403e9b 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -3,22 +3,19 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.document.token import TokenDocument from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type -class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): +class GPTRandomSampledDataset[DocumentType: LanguageModelDocument](SampledDataset[DocumentType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed self._parameters = sampling.parameters assert isinstance(sampling.preprocessing, LanguageModelPreprocessingConfig) - assert not sampling.preprocessing.use_loss_masking_spans - assert not sampling.preprocessing.use_preference_spans - assert not sampling.preprocessing.use_image_patches self._vocab_size = sampling.preprocessing.vocab_size self._dtype = get_unsigned_integer_type(self._vocab_size).torch @@ -26,19 +23,21 @@ def __init__(self, sampling: GPTSamplingData, name: str): def __len__(self) -> int: return self._parameters.num_samples - def __getitem__(self, index: int) -> SampleType: + def __getitem__(self, index: int) -> list[DocumentType]: # TODO: Sample in self._dtype (breaking) - return LanguageModelSample( - TokenSample( - torch.from_numpy( - np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( - 0, - self._vocab_size, - size=(self._parameters.sequence_length + self._parameters.extra_tokens,), - ) - ).to(self._dtype), + return [ + LanguageModelDocument( + tokens=TokenDocument( + tokens=torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, + self._vocab_size, + size=(self._parameters.sequence_length + self._parameters.extra_tokens,), + ) + ).to(self._dtype), + ) ) - ) + ] @property def name(self) -> str: diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index b2e6f7e3d..af4f72539 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -4,11 +4,11 @@ from fast_llm.data.dataset.abstract import SamplableDataset from fast_llm.data.dataset.config import SamplingData, SamplingParameters -from fast_llm.data.sample.abstract import Sample +from fast_llm.data.document.abstract import Document from fast_llm.utils import Assert, padded_cumsum -class IndexedDataset[SampleType: Sample](SamplableDataset[SampleType]): +class 
IndexedDataset[DocumentType: Document](SamplableDataset[DocumentType]): """ A dataset containing a list of samples. TODO: Move sampling responsibility here? @@ -31,7 +31,7 @@ def get_document_size(self, index: int) -> int: @abc.abstractmethod def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: + ) -> DocumentType: pass def __len__(self) -> int: @@ -55,12 +55,12 @@ def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": return SampledIndexedDataset(self, sampling) -class DatasetSlice[SampleType: Sample](IndexedDataset[SampleType]): +class DatasetSlice[DocumentType: Document](IndexedDataset[DocumentType]): def __init__( self, name: str, - dataset: IndexedDataset[SampleType], + dataset: IndexedDataset[DocumentType], begin: int | None = None, end: int | None = None, ): @@ -86,7 +86,7 @@ def get_document_size(self, index: int) -> int: def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: + ) -> DocumentType: """ Get the sample (document) with the given index (in the dataset slice), optionally subsampled to a specific offset (starting point) and maximum length @@ -102,12 +102,12 @@ def name(self) -> str: return self._name -class ConcatenatedDataset[SampleType: Sample](IndexedDataset[SampleType]): +class ConcatenatedDataset[DocumentType: Document](IndexedDataset[DocumentType]): def __init__( self, name: str, - datasets: list[IndexedDataset[SampleType]], + datasets: list[IndexedDataset[DocumentType]], ): self._name = name self._datasets = datasets @@ -134,7 +134,7 @@ def get_document_size(self, index: int) -> int: def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: + ) -> DocumentType: dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document( index - self._dataset_splits[dataset].item(), begin, end, parameters diff --git a/fast_llm/data/sample/__init__.py b/fast_llm/data/dataset/memmap/__init__.py similarity index 100% rename from fast_llm/data/sample/__init__.py rename to fast_llm/data/dataset/memmap/__init__.py diff --git a/fast_llm/data/dataset/memmap/abstract.py b/fast_llm/data/dataset/memmap/abstract.py new file mode 100644 index 000000000..6090d188a --- /dev/null +++ b/fast_llm/data/dataset/memmap/abstract.py @@ -0,0 +1,119 @@ +import abc +import io +import pathlib +import typing + +import torch + +from fast_llm.config import Configurable +from fast_llm.data.dataset.memmap.config import ( + MemmapIndexDatasetReaderConfig, + MemmapReaderBaseConfig, + MemmapReaderConfig, + NullReaderConfig, +) +from fast_llm.data.document.abstract import Document +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig +from fast_llm.utils import Assert + + +class MemmapReaderBase[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): + @abc.abstractmethod + def get_document(self, index: int, begin: int, end: int) -> Document | None: + pass + + +class NullMemmapReader[ConfigType: NullReaderConfig](MemmapReaderBase[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> None: + return None + + +class MemmapReader[ConfigType: MemmapReaderConfig](MemmapReaderBase[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + 
super().__init__(config) + # Note: This is the requirement at reading time (ex. from the model), + # which may differ from how the dataset was actually preprocessed (`config.preprocessing`) + # Compatibility checked in `MemmapDataset`. + self._model_preprocessing = NullPreprocessingConfig if model_preprocessing is None else model_preprocessing + buffer_begin = self._config.begin + len(self._config.header) + buffer_end = self._config.end - len(self._config.footer) + Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) + Assert.eq(buffer[buffer_end : self._config.end].tobytes(), self._config.footer) + self._buffer = buffer[buffer_begin:buffer_end] + + @abc.abstractmethod + def get_document(self, index: int, begin: int, end: int) -> Document: + pass + + +class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + def __len__(self) -> int: + return len(self._config) + + @property + def num_tokens(self) -> int: + return self._config.num_tokens + + @abc.abstractmethod + def get_document_sizes(self) -> "torch.Tensor": + pass + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + pass + + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + raise NotImplementedError() + + +class MemmapWriter(abc.ABC): + def __init__( + self, stream: io.BufferedWriter | pathlib.Path, preprocessing_config: PreprocessingConfig | None = None + ): + self._owns_stream = isinstance(stream, pathlib.Path) + if self._owns_stream: + stream = stream.open("wb") + self._stream = stream + self._preprocessing_config = ( + NullPreprocessingConfig() if preprocessing_config is None else preprocessing_config + ) + + def __enter__(self): + self._begin = self._stream.tell() + self._stream.write(self._get_config_class().header) + return self + + def write(self, document: Document): + assert hasattr(self, "_begin") and not hasattr(self, "_end") + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(self._get_config_class().footer) + self._end = self._stream.tell() + if self._owns_stream: + self._stream.close() + + @classmethod + @abc.abstractmethod + def _get_config_class(cls) -> type[MemmapReaderConfig]: + pass + + def get_config(self, offset: int = 0) -> MemmapReaderConfig: + assert hasattr(self, "_end") + return self._get_config(self._begin + offset, self._end + offset) + + @abc.abstractmethod + def _get_config(self, begin: int, end: int): + pass + + @classmethod + def write_dataset( + cls, + stream: io.BufferedWriter, + documents: typing.Iterable[Document], + preprocessing_config: PreprocessingConfig | None = None, + ) -> MemmapReaderConfig: + with cls(stream, preprocessing_config) as writer: + for document in documents: + writer.write(document) + return writer.get_config() diff --git a/fast_llm/data/dataset/memmap/config.py b/fast_llm/data/dataset/memmap/config.py new file mode 100644 index 000000000..ce5ecb06c --- /dev/null +++ b/fast_llm/data/dataset/memmap/config.py @@ -0,0 +1,462 @@ +import io +import logging +import math +import pathlib +import typing + +import torch + +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from 
fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert, get_unique + +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.memmap.abstract import ( + MemmapIndexedDatasetReader, + MemmapReader, + MemmapWriter, + NullMemmapReader, + ) + from fast_llm.data.dataset.memmap.language_model import LanguageModelReader, LanguageModelWriter + from fast_llm.data.dataset.memmap.patch import PatchReader, PatchWriter + from fast_llm.data.dataset.memmap.range import RangeReader, RangeWriter + from fast_llm.data.dataset.memmap.token import TokenReader, TokenWriter + from fast_llm.data.document.abstract import Document + +logger = logging.getLogger(__name__) + + +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class MemmapDatasetConfig[DocumentType: Document](IndexedDatasetConfig[DocumentType]): + _abstract: typing.ClassVar[bool] = False + path: pathlib.Path = Field( + default=None, + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[DocumentType]": + name = str(self.path).replace("/", "__") + if self.path.is_file(): + from fast_llm.data.dataset.memmap.memmap import MemmapDataset + + return MemmapDataset[DocumentType](name, self.path, preprocessing) + elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): + logger.warning( + "Using the legacy memmap dataset format." + " This format is deprecated and will be removed in a future release." + " Please recreate the dataset in the new memmap format." + ) + from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset + + return LegacyMemmapDataset[DocumentType](name, self.path, preprocessing) + else: + raise FileNotFoundError(self.path) + + +@config_class(registry=True) +class MemmapReaderBaseConfig(Config): + """ + Configuration for a memmap reader or reader-like object. + Note: `MemmapDataset` requires a `MemmapIndexedDatasetReader`. + Other readers need to be nested within a `MemmapIndexedDatasetReader` + Note: Reader configs are not typical configs, and do not need to be located in a separate `config.py` file. + """ + + _abstract = True + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is MemmapReaderBaseConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass, necessary for loading configs where some components could be absent. + return NullReaderConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + def get_reader(self, buffer: memoryview) -> "MemmapReader|None": + raise NotImplementedError() + + @property + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, including header and footer. Used for self-validation. + """ + raise NotImplementedError() + + def get_metadata(self) -> dict[str, typing.Any]: + raise NotImplementedError() + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + raise NotImplementedError() + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "none"}) +class NullReaderConfig(MemmapReaderBaseConfig): + """ + Configuration for a dynamically disabled reader. 
+ """ + + _abstract = False + + def get_reader(self, buffer: memoryview) -> "NullMemmapReader": + from fast_llm.data.dataset.memmap.abstract import NullMemmapReader + + return NullMemmapReader(self) + + @property + def expected_buffer_size(self) -> int: + return 0 + + +@config_class(registry=True) +class MemmapReaderConfig(MemmapReaderBaseConfig): + """ + Configuration for a standard memmap reader. + """ + + # Data location in the file. + begin: int = Field() + end: int = Field() + # Constant strings for alignment safety. + header: typing.ClassVar[bytes] + footer: typing.ClassVar[bytes] + # Additional information about how the dataset was prepared. + preprocessing: PreprocessingConfig = Field() + + @property + def reader_class(self) -> "type[MemmapReader]": + raise NotImplementedError() + + def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None) -> "MemmapReader": + return self.reader_class(self, buffer, model_preprocessing) + + @property + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, including header and footer. Used for self-validation. + """ + return self._expected_buffer_size + len(self.header) + len(self.footer) + + @property + def _expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, excluding header and footer. Used for self-validation. + """ + raise NotImplementedError() + + @property + def writer_class(self) -> "type[MemmapWriter]": + raise NotImplementedError() + + def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter": + return self.writer_class(stream) + + def _validate(self): + super()._validate() + Assert.eq(self.end - self.begin, self.expected_buffer_size) + + +@config_class() +class PatchReaderBaseConfig(MemmapReaderBaseConfig): + _abstract = False + patch_shape: tuple[int, ...] 
= Field() + data_type: DataType = Field() + + @property + def patch_size(self) -> int: + return math.prod(self.patch_shape) + + @property + def grid_dims(self) -> int: + return len(self.patch_shape) - 1 + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "patch"}) +class PatchReaderConfig(PatchReaderBaseConfig, MemmapReaderConfig): + header: typing.ClassVar[bytes] = b"patch begin" + footer: typing.ClassVar[bytes] = b"patch end" + num_documents: int = Field() + num_patches: int = Field() + num_patch_groups: int = Field() + + def __len__(self) -> int: + return self.num_documents + + @property + def reader_class(self) -> "type[PatchReader]": + from fast_llm.data.dataset.memmap.patch import PatchReader + + return PatchReader + + @property + def writer_class(self) -> "type[PatchWriter]": + from fast_llm.data.dataset.memmap.patch import PatchWriter + + return PatchWriter + + @property + def _expected_buffer_size(self) -> int: + return ( + self.num_patches * self.patch_size * self.data_type.torch.itemsize + + ((1 + self.grid_dims) * self.num_patches + self.num_patch_groups + 2 * self.num_documents + 2) + * torch.int32.itemsize + ) + + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_documents": self.num_documents, + "num_patches": self.num_patches, + "num_patch_groups": self.num_patch_groups, + "num_pixels": self.patch_size * self.num_patches, + "patch_shape": self.patch_shape, + "data_type": str(self.data_type), + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "num_patches": sum(metadata_["num_patches"] for metadata_ in metadata), + "num_patch_groups": sum(metadata_["num_patch_groups"] for metadata_ in metadata), + "num_pixels": sum(metadata_["num_pixels"] for metadata_ in metadata), + "patch_shape": get_unique(metadata_["patch_shape"] for metadata_ in metadata), + "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), + } + + +@config_class() +class RangeReaderBaseConfig(MemmapReaderBaseConfig): + _abstract = False + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) +class RangeReaderConfig(RangeReaderBaseConfig, MemmapReaderConfig): + header: typing.ClassVar[bytes] = b"range begin" + footer: typing.ClassVar[bytes] = b"range end" + num_documents: int = Field() + num_ranges: int = Field() + + @property + def reader_class(self) -> "type[RangeReader]": + from fast_llm.data.dataset.memmap.range import RangeReader + + return RangeReader + + @property + def writer_class(self) -> "type[RangeWriter]": + from fast_llm.data.dataset.memmap.range import RangeWriter + + return RangeWriter + + @property + def _expected_buffer_size(self) -> int: + return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize + + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_documents": self.num_documents, + "num_ranges": self.num_ranges, + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata), + } + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) +class TokenReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"token begin" + footer: typing.ClassVar[bytes] = b"token end" + num_documents: int = 
Field() + num_tokens: int = Field() + data_type: DataType = Field() + + def __len__(self) -> int: + return self.num_documents + + @property + def reader_class(self) -> "type[TokenReader]": + from fast_llm.data.dataset.memmap.token import TokenReader + + return TokenReader + + @property + def writer_class(self) -> "type[TokenWriter]": + from fast_llm.data.dataset.memmap.token import TokenWriter + + return TokenWriter + + @property + def _expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.int64.itemsize + + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_tokens": self.num_tokens, + "num_documents": self.num_documents, + "data_type": str(self.data_type), + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata), + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), + } + + +@config_class() +class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): + """ + Configuration for a standard memmap reader matching the indexed dataset interface, i.e., + consisting of a list of documents of known lengths. + """ + + def __len__(self) -> int: + raise NotImplementedError() + + @property + def num_tokens(self) -> int: + raise NotImplementedError() + + @property + def reader_class(self) -> "type[MemmapIndexedDatasetReader]": + raise NotImplementedError() + + def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer, model_preprocessing) + + def get_metadata(self) -> dict[str, typing.Any]: + return {"num_tokens": self.num_tokens} + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata)} + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) +class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"lm begin" + footer: typing.ClassVar[bytes] = b"lm end" + tokens: TokenReaderConfig = Field() + # Using dynamic type for optional readers for enabling/disabling + loss_masking_spans: MemmapReaderBaseConfig = Field() + chosen_spans: MemmapReaderBaseConfig = Field() + rejected_spans: MemmapReaderBaseConfig = Field() + image_patches: MemmapReaderBaseConfig = Field() + + def _validate(self) -> None: + super()._validate() + if isinstance(self.preprocessing, NullPreprocessingConfig): + # Address missing config, mostly for backward compatibility. + # TODO: We can't tell which dataset this comes from. + logger.warning( + f"Preprocessing configuration not specified for dataset reader, generating partial configuration from known parameters." + ) + if isinstance(self.image_patches, PatchReaderConfig): + Assert.eq(len(patch_shape := self.image_patches.patch_shape), 3) + image_patches = ImagePatchConfig(height=patch_shape[1], width=patch_shape[2]) + else: + image_patches = NullPreprocessingConfig() + self.preprocessing = LanguageModelPreprocessingConfig( + image_patches=image_patches, + use_loss_masking_spans=isinstance(self.loss_masking_spans, RangeReaderConfig), + use_preference_spans=isinstance(self.chosen_spans, RangeReaderConfig), + ) + # TODO: Avoid duplicated information. 
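+        # Consistency checks: each optional sub-reader must be present exactly when the
+        # corresponding preprocessing flag is enabled.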
+        Assert.custom(
+            isinstance,
+            self.loss_masking_spans,
+            RangeReaderConfig if self.preprocessing.use_loss_masking_spans else NullReaderConfig,
+        )
+        Assert.custom(
+            isinstance,
+            self.chosen_spans,
+            RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig,
+        )
+        Assert.custom(
+            isinstance,
+            self.rejected_spans,
+            RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig,
+        )
+        if self.preprocessing.use_image_patches:
+            Assert.custom(isinstance, self.image_patches, PatchReaderConfig)
+            Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape)
+            Assert.eq(self.image_patches.data_type, DataType.uint8)
+        else:
+            Assert.custom(isinstance, self.image_patches, NullReaderConfig)
+
+    def __len__(self) -> int:
+        return len(self.tokens)
+
+    @property
+    def num_tokens(self) -> int:
+        return self.tokens.num_tokens
+
+    @property
+    def reader_class(self) -> "type[LanguageModelReader]":
+        from fast_llm.data.dataset.memmap.language_model import LanguageModelReader
+
+        return LanguageModelReader
+
+    @property
+    def writer_class(self) -> "type[LanguageModelWriter]":
+        from fast_llm.data.dataset.memmap.language_model import LanguageModelWriter
+
+        return LanguageModelWriter
+
+    @property
+    def _expected_buffer_size(self) -> int:
+        return (
+            self.tokens.expected_buffer_size
+            + self.loss_masking_spans.expected_buffer_size
+            + self.chosen_spans.expected_buffer_size
+            + self.rejected_spans.expected_buffer_size
+            + self.image_patches.expected_buffer_size
+        )
+
+    def get_metadata(self) -> dict[str, typing.Any]:
+        out = super().get_metadata()
+        out["tokens"] = self.tokens.get_metadata()
+        if not isinstance(self.loss_masking_spans, NullReaderConfig):
+            out["loss_masking_spans"] = self.loss_masking_spans.get_metadata()
+        if not isinstance(self.chosen_spans, NullReaderConfig):
+            out["chosen_spans"] = self.chosen_spans.get_metadata()
+        if not isinstance(self.rejected_spans, NullReaderConfig):
+            out["rejected_spans"] = self.rejected_spans.get_metadata()
+        if not isinstance(self.image_patches, NullReaderConfig):
+            out["image_patches"] = self.image_patches.get_metadata()
+        return out
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        out = super().blend_metadata(metadata)
+        out["tokens"] = TokenReaderConfig.blend_metadata([metadata_["tokens"] for metadata_ in metadata])
+        if "loss_masking_spans" in metadata[0]:
+            out["loss_masking_spans"] = RangeReaderConfig.blend_metadata(
+                [metadata_["loss_masking_spans"] for metadata_ in metadata]
+            )
+        if "chosen_spans" in metadata[0]:
+            out["chosen_spans"] = RangeReaderConfig.blend_metadata(
+                [metadata_["chosen_spans"] for metadata_ in metadata]
+            )
+        if "rejected_spans" in metadata[0]:
+            out["rejected_spans"] = RangeReaderConfig.blend_metadata(
+                [metadata_["rejected_spans"] for metadata_ in metadata]
+            )
+        if "image_patches" in metadata[0]:
+            out["image_patches"] = PatchReaderConfig.blend_metadata(
+                [metadata_["image_patches"] for metadata_ in metadata]
+            )
+        return out
diff --git a/fast_llm/data/dataset/memmap/language_model.py b/fast_llm/data/dataset/memmap/language_model.py
new file mode 100644
index 000000000..34d71eba3
--- /dev/null
+++ b/fast_llm/data/dataset/memmap/language_model.py
@@ -0,0 +1,237 @@
+import io
+import pathlib
+import tempfile
+import typing
+
+import torch
+
+from fast_llm.data.dataset.memmap.abstract import MemmapIndexedDatasetReader, MemmapWriter
+from fast_llm.data.dataset.memmap.config import
LanguageModelReaderConfig, NullReaderConfig +from fast_llm.data.dataset.memmap.patch import PatchReader, PatchWriter +from fast_llm.data.dataset.memmap.range import RangeReader, RangeWriter +from fast_llm.data.dataset.memmap.token import TokenWriter +from fast_llm.data.document.abstract import Document +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.utils import Assert + + +class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + _model_preprocessing: LanguageModelPreprocessingConfig + + def __init__( + self, + config: ConfigType, + buffer: memoryview, + model_preprocessing: LanguageModelPreprocessingConfig | None = None, + ): + super().__init__(config, buffer, model_preprocessing) + self._config.preprocessing.check_compatibility(self._model_preprocessing) + # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. + self._tokens = self._config.tokens.get_reader(buffer) + null_reader = NullReaderConfig().get_reader(buffer) + self._loss_masking_spans = ( + self._config.loss_masking_spans.get_reader(buffer) + if self._model_preprocessing.use_loss_masking_spans + else null_reader + ) + self._chosen_spans = ( + self._config.chosen_spans.get_reader(buffer) + if self._model_preprocessing.use_preference_spans + else null_reader + ) + self._rejected_spans = ( + self._config.rejected_spans.get_reader(buffer) + if self._model_preprocessing.use_preference_spans + else null_reader + ) + self._image_patches = ( + self._config.image_patches.get_reader(buffer) + if self._model_preprocessing.use_image_patches + else null_reader + ) + # TODO: Make this configurable. (Add to `model_preprocessing`?) 
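+        # Patches are stored as raw uint8 and only normalized when a document is read (see `get_document`).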
+ self._image_normalization_config = ImageNormalizationConfig() + + @property + def num_tokens(self) -> int: + return self._config.tokens.num_tokens + + def get_document(self, index: int, begin: int, end: int) -> Document: + if self._model_preprocessing.use_image_patches: + image_patches = self._image_patches.get_document(index, begin, end) + if image_patches is not None: + image_patches.patches = self._image_normalization_config.normalize(image_patches.patches) + else: + image_patches = None + return LanguageModelDocument( + tokens=self._tokens.get_document(index, begin, end), + loss_masking_spans=self._loss_masking_spans.get_document(index, begin, end), + chosen_spans=self._chosen_spans.get_document(index, begin, end), + rejected_spans=self._rejected_spans.get_document(index, begin, end), + image_patches=image_patches, + ) + + def get_document_sizes(self) -> torch.Tensor: + return self._tokens.get_document_sizes() + + def get_document_size(self, index: int) -> int: + return self._tokens.get_document_size(index) + + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + begin_index, end_index, token_metadata = self._tokens.get_split(begin_ratio, end_ratio) + metadata = { + "num_tokens": token_metadata["num_tokens"], + "tokens": token_metadata, + } + if isinstance(self._loss_masking_spans, RangeReader): + metadata["loss_masking_spans"] = self._loss_masking_spans.get_split(begin_index, end_index) + if isinstance(self._chosen_spans, RangeReader): + metadata["chosen_spans"] = self._chosen_spans.get_split(begin_index, end_index) + if isinstance(self._rejected_spans, RangeReader): + metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index) + if isinstance(self._image_patches, PatchReader): + metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index) + + return begin_index, end_index, metadata + + +class LanguageModelWriter(MemmapWriter): + _preprocessing_config: LanguageModelPreprocessingConfig + + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + + self._directory = tempfile.TemporaryDirectory() + self._path = pathlib.Path(self._directory.name) + # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times. + self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__() + if self._preprocessing_config.use_loss_masking_spans: + self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() + if self._preprocessing_config.use_preference_spans: + self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() + self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() + if self._preprocessing_config.use_image_patches: + self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() + return self + + def write(self, document: LanguageModelDocument): + super().write(document) + # Write tokens. + self._token_writer.write(document.tokens) + + # Write loss masking spans. + if self._preprocessing_config.use_loss_masking_spans: + assert document.loss_masking_spans is not None + self._loss_masking_span_writer.write(document.loss_masking_spans) + + # Write preference spans. 
+ if self._preprocessing_config.use_preference_spans: + assert document.chosen_spans is not None + assert document.rejected_spans is not None + self._chosen_spans_writer.write(document.chosen_spans) + self._rejected_spans_writer.write(document.rejected_spans) + + # Write image patches + if self._preprocessing_config.use_image_patches: + assert document.image_patches is not None + self._image_patches_writer.write(document.image_patches) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._token_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_loss_masking_spans: + self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_preference_spans: + self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) + self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_image_patches: + self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) + + if exc_type is None: + # A dummy config so we can verify the begin and end offsets. + config = self._get_config(self._begin, None) + _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) + + if self._preprocessing_config.use_loss_masking_spans: + _copy_chunked( + self._path.joinpath("loss_masking_spans"), + self._stream, + config.loss_masking_spans.begin, + config.loss_masking_spans.end, + ) + if self._preprocessing_config.use_preference_spans: + _copy_chunked( + self._path.joinpath("chosen_spans"), + self._stream, + config.chosen_spans.begin, + config.chosen_spans.end, + ) + _copy_chunked( + self._path.joinpath("rejected_spans"), + self._stream, + config.rejected_spans.begin, + config.rejected_spans.end, + ) + + if self._preprocessing_config.use_image_patches: + _copy_chunked( + self._path.joinpath("image_patches"), + self._stream, + config.image_patches.begin, + config.image_patches.end, + ) + + self._directory.cleanup() + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[LanguageModelReaderConfig]: + return LanguageModelReaderConfig + + def _get_config(self, begin: int, end: int | None): + tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) + offset = tokens.end + if self._preprocessing_config.use_loss_masking_spans: + loss_masking_spans = self._loss_masking_span_writer.get_config(offset) + offset = loss_masking_spans.end + else: + loss_masking_spans = NullReaderConfig() + if self._preprocessing_config.use_preference_spans: + chosen_spans = self._chosen_spans_writer.get_config(offset) + offset = chosen_spans.end + rejected_spans = self._rejected_spans_writer.get_config(offset) + offset = rejected_spans.end + else: + chosen_spans = NullReaderConfig() + rejected_spans = NullReaderConfig() + if self._preprocessing_config.use_image_patches: + image_patches = self._image_patches_writer.get_config(offset) + offset = image_patches.end + else: + image_patches = NullReaderConfig() + + if end is None: + end = offset + len(LanguageModelReaderConfig.footer) + + return LanguageModelReaderConfig( + begin=begin, + end=end, + tokens=tokens, + loss_masking_spans=loss_masking_spans, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + image_patches=image_patches, + preprocessing=self._preprocessing_config, + ) + + +def _copy_chunked(path: pathlib.Path, stream: io.BufferedWriter, expected_begin: int, expected_end: int): + # Copy temporary file content in chunks of 100 MB. 
+ Assert.eq(stream.tell(), expected_begin) + with path.open("rb") as input_stream: + while data := input_stream.read(100000000): + stream.write(data) + Assert.eq(stream.tell(), expected_end) diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap/memmap.py similarity index 85% rename from fast_llm/data/dataset/memmap.py rename to fast_llm/data/dataset/memmap/memmap.py index e571fc433..49172e845 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap/memmap.py @@ -7,18 +7,17 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import ( - MemmapIndexDatasetReaderConfig, - MemmapIndexedDatasetReader, - MemmapWriter, - Sample, +from fast_llm.data.dataset.memmap.abstract import MemmapIndexedDatasetReader, MemmapWriter +from fast_llm.data.dataset.memmap.config import MemmapIndexDatasetReaderConfig +from fast_llm.data.document.abstract import ( + Document, ) +from fast_llm.data.preprocessing.abstract import PreprocessingConfig FILE_HEADER = b"fast_llm_prepared_dataset" -class MemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): +class MemmapDataset[DocumentType: Document](IndexedDataset[DocumentType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset. """ @@ -28,12 +27,6 @@ def read_reader_config(path: pathlib.Path | str) -> MemmapIndexDatasetReaderConf """ Read the MemmapIndexDatasetReaderConfig from a memmap file. """ - # Import reader configs to register them in the dynamic class registry - from fast_llm.data.sample.language_model import LanguageModelReaderConfig # noqa: F401 - from fast_llm.data.sample.patch import PatchReaderConfig # noqa: F401 - from fast_llm.data.sample.range import RangeReaderConfig # noqa: F401 - from fast_llm.data.sample.token import TokenReaderConfig # noqa: F401 - path = pathlib.Path(path) if isinstance(path, str) else path with path.open("rb") as stream: # Verify file type. @@ -78,7 +71,7 @@ def __del__(self): def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: + ) -> DocumentType: if end is None: end = self._reader.get_document_size(index) return self._reader.get_document(index, begin, end) @@ -108,11 +101,11 @@ def reader(self) -> MemmapIndexedDatasetReader: def write_dataset( cls, path: pathlib.Path, - documents: typing.Iterable[Sample], + documents: typing.Iterable[Document], writer_class: type[MemmapWriter], preprocessing_config: PreprocessingConfig | None = None, ) -> MemmapIndexDatasetReaderConfig: - # TODO: Match `writer_class` with `SampleType`? + # TODO: Match `writer_class` with `DocumentType`? path.parent.mkdir(parents=True, exist_ok=True) with path.open("wb") as stream: # Write the file type header. 
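Taken together, the classes above form the new on-disk pipeline: a `MemmapWriter` subclass (here `LanguageModelWriter`) streams `Document`s into a single file between its header and footer markers and returns a reader config describing the byte layout, while `MemmapDataset` later maps the file and serves documents through that config. The following is a minimal sketch of the round trip; the file path, the token values/dtype and the exact construction of the preprocessing config are illustrative assumptions, not part of this patch.

    import pathlib
    import torch

    from fast_llm.data.dataset.memmap.language_model import LanguageModelWriter
    from fast_llm.data.dataset.memmap.memmap import MemmapDataset
    from fast_llm.data.document.language_model import LanguageModelDocument
    from fast_llm.data.document.token import TokenDocument
    from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig

    # Assumed preprocessing: tokens only, no loss-masking/preference spans and no image patches.
    preprocessing = LanguageModelPreprocessingConfig(use_loss_masking_spans=False, use_preference_spans=False)

    documents = [
        LanguageModelDocument(tokens=TokenDocument(tokens=torch.tensor(token_ids, dtype=torch.int32)))
        for token_ids in ([1, 2, 3, 4], [5, 6, 7])
    ]

    path = pathlib.Path("/tmp/example_dataset")  # Illustrative location.
    # Writes the file header, delegates the documents to LanguageModelWriter, and stores the reader config.
    MemmapDataset.write_dataset(path, documents, LanguageModelWriter, preprocessing)

    # Lazily re-open the file; documents are reconstructed through the reader config embedded in it
    # (constructor arguments follow the call in MemmapDatasetConfig.build).
    dataset = MemmapDataset("example", path, preprocessing)
    first_document = dataset.get_document(0)  # LanguageModelDocument holding tokens [1, 2, 3, 4].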
diff --git a/fast_llm/data/dataset/memmap/patch.py b/fast_llm/data/dataset/memmap/patch.py new file mode 100644 index 000000000..2b551dbbf --- /dev/null +++ b/fast_llm/data/dataset/memmap/patch.py @@ -0,0 +1,141 @@ +import typing + +import numpy as np +import torch + +from fast_llm.data.dataset.memmap.abstract import MemmapReader, MemmapWriter +from fast_llm.data.dataset.memmap.config import PatchReaderConfig +from fast_llm.data.document.abstract import Document +from fast_llm.data.document.patch import PatchDocument, filter_lengths +from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + + +class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) + self._patches = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_patches * self._config.patch_size, + ).view(self._config.num_patches, *self._config.patch_shape) + offset = self._patches.nbytes + self._token_map = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_patches, + offset=offset, + ) + offset += self._token_map.nbytes + self._positions = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_patches * self._config.grid_dims, + offset=offset, + ).view(self._config.num_patches, self._config.grid_dims) + offset += self._positions.nbytes + self._patch_count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_documents + 1, + offset=offset, + ) + offset += self._patch_count_cumsums.nbytes + self._group_lengths = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_patch_groups, + offset=offset, + ) + offset += self._group_lengths.nbytes + self._group_count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_documents + 1, + offset=offset, + ) + + def get_document(self, index: int, begin: int, end: int) -> Document: + token_map = self._token_map[ + token_slice := slice(self._patch_count_cumsums[index], self._patch_count_cumsums[index + 1]) + ] + patch_filter = (token_map >= begin) & (token_map < end) + return PatchDocument( + patches=self._patches[token_slice][patch_filter], + token_map=token_map[patch_filter] - begin, + positions=self._positions[token_slice][patch_filter], + lengths=filter_lengths( + self._group_lengths[self._group_count_cumsums[index] : self._group_count_cumsums[index + 1]].tolist(), + patch_filter, + ), + ) + + def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: + Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) + num_patches = self._patch_count_cumsums[end_index].item() - self._patch_count_cumsums[begin_index].item() + return { + "num_documents": end_index - begin_index, + "num_patches": num_patches, + "num_patch_groups": self._group_count_cumsums[end_index].item() + - self._group_count_cumsums[begin_index].item(), + "num_pixels": self._config.patch_size * num_patches, + "patch_shape": self._config.patch_shape, + "data_type": str(self._config.data_type), + } + + +class PatchWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._patch_count_cumsum = [0] + self._group_count_cumsum = [0] + self._token_map = [] + self._positions = 
[] + self._group_lengths = [] + self._data_type = None + self._patch_shape = None + return self + + def write(self, document: PatchDocument): + super().write(document) + if self._data_type is None: + self._data_type = document.patches.dtype + else: + Assert.eq(self._data_type, document.patches.dtype) + if self._patch_shape is None: + self._patch_shape = tuple(document.patches.shape[1:]) + else: + Assert.eq(self._patch_shape, document.patches.shape[1:]) + self._stream.write(document.patches.numpy().tobytes()) + self._token_map.extend(document.token_map) + self._positions.extend(document.positions) + self._patch_count_cumsum.append(self._patch_count_cumsum[-1] + len(document.patches)) + self._group_count_cumsum.append(self._group_count_cumsum[-1] + len(document.lengths)) + self._group_lengths.extend(document.lengths) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + Assert.lt(self._patch_count_cumsum[-1], np.iinfo(np.int32).max) + self._stream.write(np.array(self._token_map, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._positions, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._patch_count_cumsum, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._group_lengths, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._group_count_cumsum, dtype=np.int32).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[PatchReaderConfig]: + return PatchReaderConfig + + def _get_config(self, begin: int, end: int): + return PatchReaderConfig( + begin=begin, + end=end, + num_documents=len(self._patch_count_cumsum) - 1, + num_patches=self._patch_count_cumsum[-1], + num_patch_groups=self._group_count_cumsum[-1], + patch_shape=self._patch_shape, + data_type=DataType.from_torch(self._data_type), + preprocessing=self._preprocessing_config, + ) diff --git a/fast_llm/data/dataset/memmap/range.py b/fast_llm/data/dataset/memmap/range.py new file mode 100644 index 000000000..9bd1a3119 --- /dev/null +++ b/fast_llm/data/dataset/memmap/range.py @@ -0,0 +1,73 @@ +import typing + +import numpy as np +import torch + +from fast_llm.data.dataset.memmap.abstract import MemmapReader, MemmapWriter +from fast_llm.data.dataset.memmap.config import RangeReaderConfig +from fast_llm.data.document.abstract import Document +from fast_llm.data.document.range import RangeDocument +from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.utils import Assert + + +class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) + self._ranges = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_ranges * 2, + ).view(-1, 2) + self._count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_documents + 1, + offset=self._ranges.nbytes, + ) + + def get_document(self, index: int, begin: int, end: int) -> Document: + sample_size = end - begin + cropped_ranges = ( + (max(begin_ - begin, 0), min(end_ - begin, sample_size)) + for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist() + ) + return RangeDocument(ranges=[(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]) + + def get_split(self, begin_index: int, end_index: int) -> dict[str, 
typing.Any]: + Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) + return { + "num_documents": end_index - begin_index, + "num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(), + } + + +class RangeWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._count_cumsum = [0] + return self + + def write(self, document: RangeDocument): + super().write(document) + self._stream.write(np.array(document.ranges, dtype=np.int32).tobytes(order="C")) + self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + Assert.lt(self._count_cumsum[-1], np.iinfo(np.int32).max) + self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[RangeReaderConfig]: + return RangeReaderConfig + + def _get_config(self, begin: int, end: int): + return RangeReaderConfig( + begin=begin, + end=end, + num_documents=len(self._count_cumsum) - 1, + num_ranges=self._count_cumsum[-1], + preprocessing=self._preprocessing_config, + ) diff --git a/fast_llm/data/dataset/memmap/token.py b/fast_llm/data/dataset/memmap/token.py new file mode 100644 index 000000000..7d4bcbc39 --- /dev/null +++ b/fast_llm/data/dataset/memmap/token.py @@ -0,0 +1,95 @@ +import typing + +import numpy as np +import torch + +from fast_llm.data.dataset.memmap.abstract import MemmapIndexedDatasetReader, MemmapWriter +from fast_llm.data.dataset.memmap.config import TokenReaderConfig +from fast_llm.data.document.abstract import Document +from fast_llm.data.document.token import TokenDocument +from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + + +class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) + self._tokens = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens, + ) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Document: + begin_ = self._size_cumsums[index].item() + # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. 
+ # Convert begin and end to int to avoid numpy dtype overflow when adding to begin_ + return TokenDocument(tokens=self._tokens[begin_ + begin : begin_ + end].to(torch.int64)) + + def get_document_sizes(self) -> torch.Tensor: + return self._size_cumsums[1:] - self._size_cumsums[:-1] + + def get_document_size(self, index: int) -> int: + return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item() + + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1]) + begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens) + end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens) + + return ( + begin_index, + end_index, + { + "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(), + "num_documents": end_index - begin_index, + "data_type": str(self._config.data_type), + }, + ) + + +def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: + left = torch.searchsorted(cumsum, value, side="right") + if left == len(cumsum): + return left.item() + return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + + +class TokenWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + return self + + def write(self, document: TokenDocument): + super().write(document) + if self._data_type is None: + self._data_type = document.tokens.dtype + else: + Assert.eq(self._data_type, document.tokens.dtype) + self._stream.write(document.tokens.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[TokenReaderConfig]: + return TokenReaderConfig + + def _get_config(self, begin: int, end: int): + return TokenReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + data_type=DataType.from_torch(self._data_type), + preprocessing=self._preprocessing_config, + ) diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 01f3195e4..ab4af957b 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -2,7 +2,7 @@ import time from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.sample.abstract import Sample +from fast_llm.data.document.abstract import Document try: from fast_llm.csrc.data import build_blending_indices # noqa @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class DatasetMonitor[SampleType: Sample](SampledDataset[SampleType]): +class DatasetMonitor[DocumentType: Document](SampledDataset[DocumentType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. 
The sampling order of each dataset is respected, but there is no strict guarantee @@ -24,7 +24,7 @@ class DatasetMonitor[SampleType: Sample](SampledDataset[SampleType]): def __init__( self, - dataset: SampledDataset[SampleType], + dataset: SampledDataset[DocumentType], data_sample_warn_time_ms: float, ): self._dataset = dataset @@ -33,7 +33,7 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, index: int) -> SampleType: + def __getitem__(self, index: int) -> list[DocumentType]: start_time = time.perf_counter() try: sample = self._dataset[index] diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 36b52d9f8..a3b7c05a5 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -11,7 +11,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.sample.abstract import Sample +from fast_llm.data.document.abstract import Document from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -66,14 +66,14 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): +class SampledIndexedDataset[DocumentType: Document](SampledDataset[DocumentType]): """ A sampled dataset. """ def __init__( self, - indexed_dataset: IndexedDataset[SampleType], + indexed_dataset: IndexedDataset[DocumentType], sampling: SamplingData, ): self._indexed_dataset = indexed_dataset @@ -126,17 +126,17 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 + long_docs_filter = document_sizes > self._parameters.total_length ignored_documents = long_docs_filter.sum().item() if ignored_documents: log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", + f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.total_length} tokens and will be ignored.", log_fn=logger.warning, ) tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() if tokens_per_epoch == 0: raise RuntimeError( - f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." + f" > No documents shorter than {self._parameters.total_length} tokens found in dataset {self._indexed_dataset.name}." ) # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, @@ -148,10 +148,7 @@ def _sample(self) -> None: / tokens_per_epoch ) else: - num_epochs = math.ceil( - ((self._parameters.sequence_length + self._parameters.extra_tokens) * self._parameters.num_samples) - / tokens_per_epoch - ) + num_epochs = math.ceil((self._parameters.total_length * self._parameters.num_samples) / tokens_per_epoch) # Prepare for shuffling. 
generator = torch.Generator(device=self._device) @@ -320,14 +317,12 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - else: # TODO: dynamically handle int64 or int32 in CPP out = build_padded_token_cumsum( - sizes.cpu().numpy(), (self._parameters.sequence_length + 1), TOKEN_CUMSUM_RATE, offset + sizes.cpu().numpy(), self._parameters.total_length, TOKEN_CUMSUM_RATE, offset ) num_tokens = out[-1] out = out[:-1][ : np.clip( - np.searchsorted( - out, self._parameters.num_samples * (self._parameters.sequence_length + 1), side="right" - ), + np.searchsorted(out, self._parameters.num_samples * self._parameters.total_length, side="right"), 0, None, ) @@ -337,7 +332,7 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - def __len__(self) -> int: return self._parameters.num_samples - def __getitem__(self, index: int) -> SampleType: + def __getitem__(self, index: int) -> list[DocumentType]: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. @@ -347,13 +342,10 @@ def __getitem__(self, index: int) -> SampleType: # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample - sample_length = ( - self._parameters.sequence_length - if self._truncate_documents - else self._parameters.sequence_length + self._parameters.extra_tokens + token_start = index * ( + self._parameters.sequence_length if self._truncate_documents else self._parameters.total_length ) - token_start = index * sample_length - token_end = token_start + self._parameters.sequence_length + self._parameters.extra_tokens + token_end = token_start + self._parameters.total_length if token_start < self._unshuffled_tokens: token_start_array = self._token_cumsum_unshuffled.array @@ -369,7 +361,7 @@ def __getitem__(self, index: int) -> SampleType: token_count = token_start_array[token_start_cumsum_index].item() - documents: list[SampleType] = [] + documents: list[DocumentType] = [] while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -380,16 +372,15 @@ def __getitem__(self, index: int) -> SampleType: document_size = self._indexed_dataset.get_document_size(document_index) if not self._truncate_documents: - if document_size > self._parameters.sequence_length + 1: + if document_size > self._parameters.total_length: # Document too long, ignore document_sampling_index += 1 continue - tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample > self._parameters.sequence_length + 1: + tokens_in_sample = token_count % self._parameters.total_length + if document_size + tokens_in_sample > self._parameters.total_length: # Document belongs to the next sample, need to account for padding. - padding_size = self._parameters.sequence_length + 1 - tokens_in_sample + padding_size = self._parameters.total_length - tokens_in_sample if token_count > token_start: - documents.append(documents[-1].get_padding(padding_size)) Assert.eq(token_count + padding_size, token_end) break else: @@ -413,8 +404,7 @@ def __getitem__(self, index: int) -> SampleType: # Go to the next document. 
document_sampling_index += 1 token_count += document_size - - return documents[0].from_documents(documents) + return documents @property def name(self) -> str: diff --git a/fast_llm/data/document/__init__.py b/fast_llm/data/document/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/document/abstract.py b/fast_llm/data/document/abstract.py new file mode 100644 index 000000000..eb6accfdc --- /dev/null +++ b/fast_llm/data/document/abstract.py @@ -0,0 +1,23 @@ +import abc +import dataclasses + + +@dataclasses.dataclass(kw_only=True) +class Document(abc.ABC): + pass + + +@dataclasses.dataclass(kw_only=True) +class Batch(Document): + pass + # @classmethod + # @abc.abstractmethod + # def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + # pass + + # @abc.abstractmethod + # def crop(self, begin: int, end: int) -> typing.Self: + # pass + + # def to_device_(self, device: "torch.device | str"): + # pass diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py new file mode 100644 index 000000000..23de0605b --- /dev/null +++ b/fast_llm/data/document/language_model.py @@ -0,0 +1,90 @@ +import dataclasses +import logging +import typing + +import torch + +from fast_llm.data.document.abstract import Batch, Document +from fast_llm.data.document.patch import PatchBatch, PatchDocument +from fast_llm.data.document.range import RangeBatch, RangeDocument +from fast_llm.data.document.token import TokenBatch, TokenDocument +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(kw_only=True) +class LanguageModelDocument(Document): + tokens: TokenDocument + loss_masking_spans: RangeDocument | None = None + chosen_spans: RangeDocument | None = None + rejected_spans: RangeDocument | None = None + image_patches: PatchDocument | None = None + + def __len__(self) -> int: + return len(self.tokens) + + +@dataclasses.dataclass(kw_only=True) +class LanguageModelBatch(LanguageModelDocument, Batch): + tokens: TokenBatch + loss_masking_spans: RangeBatch | None = None + chosen_spans: RangeBatch | None = None + rejected_spans: RangeBatch | None = None + image_patches: PatchBatch | None = None + num_tokens: int # Number of tokens in the micro-batch excluding padding at the end. 
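+    # `from_documents` packs several documents into a single fixed-length sample; when `pad_to_size`
+    # is given, the remainder is filled with a padding document whose token id is -100.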
+ + @classmethod + def from_documents( + cls, documents: typing.Iterable[LanguageModelDocument], pad_to_size: int | None = None + ) -> typing.Self: + num_tokens = sum(len(document) for document in documents) + if pad_to_size is not None: + Assert.geq(pad_to_size, num_tokens) + padding = pad_to_size - num_tokens + if padding > 0: + documents = documents + [ + LanguageModelDocument( + tokens=TokenDocument(tokens=documents[0].tokens.tokens.new_full([padding], -100)) + ) + ] + sizes = [len(document) for document in documents] + return cls( + tokens=TokenBatch.from_documents([document.tokens for document in documents]), + loss_masking_spans=RangeBatch.from_documents( + [document.loss_masking_spans for document in documents], sizes + ), + chosen_spans=RangeBatch.from_documents([document.chosen_spans for document in documents], sizes), + rejected_spans=RangeBatch.from_documents([document.rejected_spans for document in documents], sizes), + image_patches=PatchBatch.from_documents([document.image_patches for document in documents], sizes), + num_tokens=num_tokens, + ) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + tokens=self.tokens.crop(begin, end), + loss_masking_spans=_crop_optional(self.loss_masking_spans, begin, end), + chosen_spans=_crop_optional(self.chosen_spans, begin, end), + rejected_spans=_crop_optional(self.rejected_spans, begin, end), + image_patches=_crop_optional(self.image_patches, begin, end), + num_tokens=min(end, self.num_tokens) - begin, + ) + + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) + if self.image_patches is not None: + self.image_patches.to_device_(device) + + +def _merge_optional[T](fn: typing.Callable, args: typing.Iterable) -> T | None: + return None if any(arg is None for arg in args) else fn(args) + + +def _crop_optional[T: Document](sample: T, begin: int, end: int) -> T | None: + return None if sample is None else sample.crop(begin, end) diff --git a/fast_llm/data/document/patch.py b/fast_llm/data/document/patch.py new file mode 100644 index 000000000..64bc2841b --- /dev/null +++ b/fast_llm/data/document/patch.py @@ -0,0 +1,66 @@ +import dataclasses +import typing + +import torch + +from fast_llm.data.document.abstract import Batch, Document +from fast_llm.utils import Assert, padded_cumsum + + +def filter_lengths(lengths: list[int], filter: torch.Tensor) -> list[int]: + length_cumsum = padded_cumsum(lengths) + filtered_lengths = (filter[begin:end].sum().item() for begin, end in zip(length_cumsum[:-1], length_cumsum[1:])) + return [length for length in filtered_lengths if length > 0] + + +@dataclasses.dataclass(kw_only=True) +class PatchDocument(Document): + """ + A reusable component holding a set of fixed-shape patches (ex. images, audio, video), + each of which providing a single token embedding in a multimodal model. + """ + + patches: torch.Tensor + token_map: torch.Tensor + positions: torch.Tensor # Position identifier for each patch in the patch grid. + lengths: list[int] # Length of each patch group (ex. image) in the document. TODO: Use cumsums instead? 
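+    # `token_map` gives, for each patch, the index of the token whose embedding it provides;
+    # it is shifted accordingly when documents are packed (`from_documents`) or cropped (`crop`).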
+ + def __post_init__(self): + Assert.eq(self.positions.shape, (self.patches.size(0), self.patches.ndim - 2)) + Assert.eq(sum(self.lengths), len(self.patches)) + + +@dataclasses.dataclass(kw_only=True) +class PatchBatch(PatchDocument, Batch): + @classmethod + def from_documents(cls, documents: typing.Iterable[PatchDocument], sizes: typing.Iterable[int]) -> typing.Self: + document_begin = 0 + embedding_maps = [] + for document, size in zip(documents, sizes, strict=True): + if document is not None: + embedding_maps.append(document.token_map + document_begin) + document_begin += size + return ( + cls( + patches=torch.cat([document.patches for document in documents if document is not None]), + token_map=torch.cat(embedding_maps), + positions=torch.cat([document.positions for document in documents if document is not None]), + lengths=sum((document.lengths for document in documents if document is not None), []), + ) + if embedding_maps + else None + ) + + def crop(self, begin: int, end: int) -> typing.Self: + patch_filter = (self.token_map >= begin) & (self.token_map < end) + return self.__class__( + patches=self.patches[patch_filter], + token_map=self.token_map[patch_filter] - begin, + positions=self.positions[patch_filter], + lengths=filter_lengths(self.lengths, patch_filter), + ) + + def to_device_(self, device: "torch.device | str"): + self.patches = self.patches.to(device, non_blocking=True) + self.token_map = self.token_map.to(device, non_blocking=True) + self.positions = self.positions.to(device, non_blocking=True) diff --git a/fast_llm/data/document/range.py b/fast_llm/data/document/range.py new file mode 100644 index 000000000..27efe50fc --- /dev/null +++ b/fast_llm/data/document/range.py @@ -0,0 +1,37 @@ +import dataclasses +import typing + +from fast_llm.data.document.abstract import Batch, Document + + +@dataclasses.dataclass(kw_only=True) +class RangeDocument(Document): + """ + A reusable component holding a set of ranges in a sample. + """ + + ranges: list[tuple[int, int]] + + +@dataclasses.dataclass(kw_only=True) +class RangeBatch(RangeDocument, Batch): + @classmethod + def from_documents( + cls, documents: typing.Iterable[RangeDocument | None], sizes: typing.Iterable[int] + ) -> typing.Self: + """ + Used to merge ranges from multiple documents, i.e. when multiple documents are packed together. + """ + document: RangeDocument + ranges = [] + document_begin = 0 + for document, size in zip(documents, sizes, strict=True): + if document is not None: + for begin, end in document.ranges: + ranges.append((begin + document_begin, end + document_begin)) + document_begin += size + return cls(ranges=ranges) if ranges else None + + def crop(self, begin: int, end: int) -> typing.Self: + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges) + return self.__class__(ranges=[(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]) diff --git a/fast_llm/data/document/token.py b/fast_llm/data/document/token.py new file mode 100644 index 000000000..529068170 --- /dev/null +++ b/fast_llm/data/document/token.py @@ -0,0 +1,105 @@ +import dataclasses +import typing + +import torch + +from fast_llm.data.document.abstract import Batch, Document +from fast_llm.utils import Assert, padded_cumsum + + +def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]: + if len(lengths) == 1: + # Shortcut for the frequent case of a single document. 
diff --git a/fast_llm/data/document/token.py b/fast_llm/data/document/token.py
new file mode 100644
index 000000000..529068170
--- /dev/null
+++ b/fast_llm/data/document/token.py
@@ -0,0 +1,105 @@
+import dataclasses
+import typing
+
+import torch
+
+from fast_llm.data.document.abstract import Batch, Document
+from fast_llm.utils import Assert, padded_cumsum
+
+
+def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]:
+    if len(lengths) == 1:
+        # Shortcut for the frequent case of a single document.
+        return [end - begin]
+    begin_ = 0
+    lengths_ = []
+    for length in lengths:
+        end_ = begin_ + length
+        cropped_length = min(end_, end) - max(begin_, begin)
+        if cropped_length > 0:
+            lengths_.append(cropped_length)
+        if end_ > end:
+            break
+        begin_ = end_
+    return lengths_
+
+
+@dataclasses.dataclass(kw_only=True)
+class TokenDocument(Document):
+    tokens: torch.Tensor
+
+    def __len__(self) -> int:
+        return len(self.tokens)
+
+
+@dataclasses.dataclass(kw_only=True)
+class TokenBatch(TokenDocument, Batch):
+    lengths: list[int]
+    sequence_k_past: int = 0
+    current_document_begin: int = 0
+
+    def __post_init__(self):
+        Assert.eq(sum(self.lengths), len(self.tokens))
+
+    @classmethod
+    def from_documents(cls, documents: typing.Iterable[TokenDocument]) -> typing.Self:
+        return cls(
+            tokens=torch.cat([document.tokens for document in documents]),
+            lengths=[len(document) for document in documents],
+        )
+
+    def crop(self, begin: int, end: int) -> typing.Self:
+        Assert.eq(self.sequence_k_past, self.current_document_begin, 0)
+
+        document_begin = 0
+        lengths_ = []
+        current_document_begin = None
+        for length in self.lengths:
+            document_end = document_begin + length
+            cropped_length = min(document_end, end) - max(document_begin, begin)
+            if cropped_length > 0:
+                lengths_.append(cropped_length)
+                # Compare against `None` explicitly: the first document may legitimately begin at offset 0.
+                if current_document_begin is None:
+                    current_document_begin = document_begin
+            if document_end > end:
+                break
+            document_begin = document_end
+
+        return self.__class__(
+            tokens=self.tokens[begin:end],
+            lengths=lengths_,
+            sequence_k_past=begin,
+            current_document_begin=current_document_begin,
+        )
+
+    def to_device_(self, device: "torch.device | str"):
+        # Also standardize the dtype while we're here.
+        self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True)
+
+    def get_cumulative_lengths(self, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
+        cumulative_lengths_q = torch.from_numpy(padded_cumsum(self.lengths)).to(dtype=torch.int32, device=device)
+        cumulative_lengths_k = torch.cat(
+            [
+                # `current_document_begin` is a plain int, so wrap it in a one-element tensor before concatenating.
+                cumulative_lengths_q.new_tensor([self.current_document_begin]),
+                cumulative_lengths_q[1:] + self.sequence_k_past,
+            ]
+        )
+        return cumulative_lengths_q, cumulative_lengths_k
+
+    def get_max_lengths(self, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
+        max_length_q = max(self.lengths)
+        max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.current_document_begin)
+        return (
+            torch.full((1,), max_length_q, dtype=torch.int32, device=device),
+            torch.full((1,), max_length_k, dtype=torch.int32, device=device),
+        )
+
+    def get_document_index(self, device: torch.device | None = None) -> torch.Tensor:
+        return torch.cat(
+            [
+                torch.full((document_length,), i, dtype=torch.int32, device=device)
+                for i, document_length in enumerate(self.lengths)
+            ]
+        )
+
+    def get_position_index(self, device: torch.device | None = None) -> torch.Tensor:
+        return torch.cat(
+            [torch.arange(document_length, dtype=torch.int32, device=device) for document_length in self.lengths]
+        )
diff --git a/fast_llm/data/preparator/dataset_discovery/prepare.py b/fast_llm/data/preparator/dataset_discovery/prepare.py
index 25a29ca3e..f1fc6a63b 100644
--- a/fast_llm/data/preparator/dataset_discovery/prepare.py
+++ b/fast_llm/data/preparator/dataset_discovery/prepare.py
@@ -11,7 +11,7 @@
 
 import yaml
 
-from fast_llm.data.dataset.memmap import MemmapDataset
+from fast_llm.data.dataset.memmap.memmap import MemmapDataset
 from fast_llm.data.preparator.config import DatasetPreparator
 from
fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 325d33c43..4d642d3b0 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -23,10 +23,15 @@ BlendedDatasetConfig, DatasetSliceConfig, IndexedDatasetConfig, - MemmapDatasetConfig, SampledDatasetConfig, ) -from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.memmap.config import MemmapDatasetConfig, MemmapIndexDatasetReaderConfig +from fast_llm.data.dataset.memmap.language_model import LanguageModelWriter +from fast_llm.data.dataset.memmap.memmap import MemmapDataset +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.document.patch import PatchDocument +from fast_llm.data.document.range import RangeDocument +from fast_llm.data.document.token import TokenDocument from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import ( ConversationSourceConfig, @@ -37,11 +42,6 @@ from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer -from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig -from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter -from fast_llm.data.sample.patch import PatchSample -from fast_llm.data.sample.range import RangeSample -from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import normalize_probabilities, padded_cumsum @@ -59,7 +59,7 @@ class SpanType(enum.StrEnum): class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): _tokenizer: Tokenizer _data_type: DataType - _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample + _sample_type: typing.ClassVar[type[LanguageModelDocument]] = LanguageModelDocument _config: GPTMemmapDatasetPreparatorConfig def __init__(self, config: ConfigType): @@ -224,7 +224,7 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: use_preference_spans=self._source_schema.has_preference_spans, ) - def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: + def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelDocument: token_spans_by_type = collections.defaultdict(list) image_patches = image_token_maps = image_position_ids = patch_counts = None @@ -332,28 +332,33 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: else: raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}") - sample_size = len(tokens) + len(tokens) - return LanguageModelSample( - TokenSample(tokens, [sample_size]), - ( - RangeSample(token_spans_by_type[SpanType.loss_masking], sample_size) + return LanguageModelDocument( + tokens=TokenDocument(tokens=tokens), + loss_masking_spans=( + RangeDocument(ranges=token_spans_by_type[SpanType.loss_masking]) if self._source_schema.has_loss_masking_span else None ), - ( - RangeSample(token_spans_by_type[SpanType.chosen], sample_size) + chosen_spans=( + 
RangeDocument(ranges=token_spans_by_type[SpanType.chosen]) if self._source_schema.has_preference_spans else None ), - ( + rejected_spans=( # `tokenize_with_spans` excludes the final eod token from the rejected span, but we want to include it. - RangeSample([(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]], sample_size) + RangeDocument(ranges=[(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]]) if self._source_schema.has_preference_spans else None ), - ( - PatchSample(image_patches, image_token_maps, image_position_ids, sample_size, patch_counts) + image_patches=( + PatchDocument( + patches=image_patches, + token_map=image_token_maps, + positions=image_position_ids, + lengths=patch_counts, + ) if self._source_schema.has_images else None ), diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py deleted file mode 100644 index c5dcf165e..000000000 --- a/fast_llm/data/sample/abstract.py +++ /dev/null @@ -1,270 +0,0 @@ -import abc -import io -import pathlib -import typing - -from fast_llm.config import Config, Configurable, Field, config_class -from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - import torch - - -class Sample(abc.ABC): - @classmethod - @abc.abstractmethod - def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: - pass - - @abc.abstractmethod - def crop(self, begin: int, end: int) -> typing.Self: - pass - - @abc.abstractmethod - def __len__(self) -> int: - pass - - @abc.abstractmethod - def get_padding(self, size: int) -> typing.Self: - pass - - def to_device_(self, device: "torch.device | str"): - pass - - -class Batch(abc.ABC): - # TODO: Relate to `BatchConfig`? - @classmethod - @abc.abstractmethod - def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: - pass - - @abc.abstractmethod - def crop(self, begin: int, end: int) -> typing.Self: - pass - - def to_device_(self, device: "torch.device | str"): - pass - - -@config_class(registry=True) -class MemmapReaderBaseConfig(Config): - """ - Configuration for a memmap reader or reader-like object. - Note: `MemmapDataset` requires a `MemmapIndexedDatasetReader`. - Other readers need to be nested within a `MemmapIndexedDatasetReader` - Note: Reader configs are not typical configs, and do not need to be located in a separate `config.py` file. - """ - - _abstract = True - - @classmethod - def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: - if cls is MemmapReaderBaseConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass, necessary for loading configs where some components could be absent. - return NullReaderConfig._from_dict(default, strict) - return super()._from_dict(default, strict=strict) - - def get_reader(self, buffer: memoryview) -> "MemmapReader|None": - raise NotImplementedError() - - @property - def expected_buffer_size(self) -> int: - """ - The expected buffer size in bytes, including header and footer. Used for self-validation. 
- """ - raise NotImplementedError() - - def get_metadata(self) -> dict[str, typing.Any]: - raise NotImplementedError() - - @classmethod - def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: - raise NotImplementedError() - - -@config_class(dynamic_type={MemmapReaderBaseConfig: "none"}) -class NullReaderConfig(MemmapReaderBaseConfig): - """ - Configuration for a dynamically disabled reader. - """ - - _abstract = False - - def get_reader(self, buffer: memoryview) -> None: - return None - - @property - def expected_buffer_size(self) -> int: - return 0 - - -@config_class(registry=True) -class MemmapReaderConfig(MemmapReaderBaseConfig): - """ - Configuration for a standard memmap reader. - """ - - # Data location in the file. - begin: int = Field() - end: int = Field() - # Constant strings for alignment safety. - header: typing.ClassVar[bytes] - footer: typing.ClassVar[bytes] - # Additional information about how the dataset was prepared. - preprocessing: PreprocessingConfig = Field() - - @property - def reader_class(self) -> "type[MemmapReader]": - raise NotImplementedError() - - def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None) -> "MemmapReader": - return self.reader_class(self, buffer, model_preprocessing) - - @property - def expected_buffer_size(self) -> int: - """ - The expected buffer size in bytes, including header and footer. Used for self-validation. - """ - return self._expected_buffer_size + len(self.header) + len(self.footer) - - @property - def _expected_buffer_size(self) -> int: - """ - The expected buffer size in bytes, excluding header and footer. Used for self-validation. - """ - raise NotImplementedError() - - @property - def writer_class(self) -> "type[MemmapWriter]": - raise NotImplementedError() - - def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter": - return self.writer_class(stream) - - def _validate(self): - super()._validate() - Assert.eq(self.end - self.begin, self.expected_buffer_size) - - -@config_class() -class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): - """ - Configuration for a standard memmap reader matching the indexed dataset interface, i.e., - consisting of a list of documents of known lengths. - """ - - def __len__(self) -> int: - raise NotImplementedError() - - @property - def num_tokens(self) -> int: - raise NotImplementedError() - - @property - def reader_class(self) -> "type[MemmapIndexedDatasetReader]": - raise NotImplementedError() - - def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader": - return self.reader_class(self, buffer, model_preprocessing) - - def get_metadata(self) -> dict[str, typing.Any]: - return {"num_tokens": self.num_tokens} - - @classmethod - def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: - return {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata)} - - -class MemmapReaderBase[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): - @abc.abstractmethod - def get_document(self, index: int, begin: int, end: int) -> Sample: - pass - - -class MemmapReader[ConfigType: MemmapReaderConfig](MemmapReaderBase[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): - super().__init__(config) - # Note: This is the requirement at reading time (ex. 
from the model), - # which may differ from how the dataset was actually preprocessed (`config.preprocessing`) - # Compatibility checked in `MemmapDataset`. - self._model_preprocessing = NullPreprocessingConfig if model_preprocessing is None else model_preprocessing - buffer_begin = self._config.begin + len(self._config.header) - buffer_end = self._config.end - len(self._config.footer) - Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) - Assert.eq(buffer[buffer_end : self._config.end].tobytes(), self._config.footer) - self._buffer = buffer[buffer_begin:buffer_end] - - -class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): - def __len__(self) -> int: - return len(self._config) - - @property - def num_tokens(self) -> int: - return self._config.num_tokens - - @abc.abstractmethod - def get_document_sizes(self) -> "torch.Tensor": - pass - - @abc.abstractmethod - def get_document_size(self, index: int) -> int: - pass - - def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: - raise NotImplementedError() - - -class MemmapWriter(abc.ABC): - def __init__( - self, stream: io.BufferedWriter | pathlib.Path, preprocessing_config: PreprocessingConfig | None = None - ): - self._owns_stream = isinstance(stream, pathlib.Path) - if self._owns_stream: - stream = stream.open("wb") - self._stream = stream - self._preprocessing_config = ( - NullPreprocessingConfig() if preprocessing_config is None else preprocessing_config - ) - - def __enter__(self): - self._begin = self._stream.tell() - self._stream.write(self._get_config_class().header) - return self - - def write(self, document: Sample): - assert hasattr(self, "_begin") and not hasattr(self, "_end") - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - self._stream.write(self._get_config_class().footer) - self._end = self._stream.tell() - if self._owns_stream: - self._stream.close() - - @classmethod - @abc.abstractmethod - def _get_config_class(cls) -> type[MemmapReaderConfig]: - pass - - def get_config(self, offset: int = 0) -> MemmapReaderConfig: - assert hasattr(self, "_end") - return self._get_config(self._begin + offset, self._end + offset) - - @abc.abstractmethod - def _get_config(self, begin: int, end: int): - pass - - @classmethod - def write_dataset( - cls, - stream: io.BufferedWriter, - documents: typing.Iterable[Sample], - preprocessing_config: PreprocessingConfig | None = None, - ) -> MemmapReaderConfig: - with cls(stream, preprocessing_config) as writer: - for document in documents: - writer.write(document) - return writer.get_config() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py deleted file mode 100644 index db7e89d87..000000000 --- a/fast_llm/data/sample/language_model.py +++ /dev/null @@ -1,511 +0,0 @@ -import io -import logging -import pathlib -import tempfile -import typing -import warnings - -import torch - -from fast_llm.config import Field, config_class -from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig -from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig, ImagePatchConfig -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -from fast_llm.data.sample.abstract import ( - Batch, - MemmapIndexDatasetReaderConfig, - MemmapIndexedDatasetReader, - MemmapReaderBaseConfig, - MemmapWriter, - NullReaderConfig, - Sample, -) -from fast_llm.data.sample.patch import ( 
- EmptyPatchReader, - PatchBatch, - PatchReader, - PatchReaderBaseConfig, - PatchReaderConfig, - PatchSample, - PatchWriter, -) -from fast_llm.data.sample.range import ( - EmptyRangeReader, - RangeBatch, - RangeReader, - RangeReaderBaseConfig, - RangeReaderConfig, - RangeSample, - RangeWriter, -) -from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -class LanguageModelSample(Sample): - def __init__( - self, - tokens: TokenSample, - loss_masking_spans: RangeSample | None = None, - chosen_spans: RangeSample | None = None, - rejected_spans: RangeSample | None = None, - image_patches: PatchSample | None = None, - ): - self.tokens = tokens - self.loss_masking_spans = loss_masking_spans - self.chosen_spans = chosen_spans - self.rejected_spans = rejected_spans - self.image_patches = image_patches - - @classmethod - def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: - return cls( - TokenSample.from_documents([document.tokens for document in documents]), - _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), - _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), - _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), - _merge_optional(PatchSample.from_documents, [document.image_patches for document in documents]), - ) - - def crop(self, begin: int, end: int) -> typing.Self: - return self.__class__( - self.tokens.crop(begin, end), - _crop_optional(self.loss_masking_spans, begin, end), - _crop_optional(self.chosen_spans, begin, end), - _crop_optional(self.rejected_spans, begin, end), - _crop_optional(self.image_patches, begin, end), - ) - - def __len__(self) -> int: - return len(self.tokens) - - def get_padding(self, size: int) -> typing.Self: - return LanguageModelSample( - self.tokens.get_padding(size), - None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), - None if self.chosen_spans is None else self.chosen_spans.get_padding(size), - None if self.rejected_spans is None else self.rejected_spans.get_padding(size), - None if self.image_patches is None else self.image_patches.get_padding(size), - ) - - def to_device_(self, device: "torch.device | str"): - self.tokens.to_device_(device) - if self.loss_masking_spans is not None: - self.loss_masking_spans.to_device_(device) - if self.chosen_spans is not None: - self.chosen_spans.to_device_(device) - if self.rejected_spans is not None: - self.rejected_spans.to_device_(device) - if self.image_patches is not None: - self.image_patches.to_device_(device) - - -class LanguageModelBatch(Batch): - def __init__( - self, - tokens: TokenBatch, - loss_masking_spans: RangeBatch | None = None, - chosen_spans: RangeBatch | None = None, - rejected_spans: RangeBatch | None = None, - image_patches: PatchBatch | None = None, - ): - self.tokens = tokens - self.loss_masking_spans = loss_masking_spans - self.chosen_spans = chosen_spans - self.rejected_spans = rejected_spans - self.image_patches = image_patches - - @classmethod - def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: - return cls( - TokenBatch.from_samples([sample.tokens for sample in samples]), - _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), - 
_merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), - _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), - _merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]), - ) - - def crop(self, begin: int, end: int) -> typing.Self: - return self.__class__( - self.tokens.crop(begin, end), - _crop_optional(self.loss_masking_spans, begin, end), - _crop_optional(self.chosen_spans, begin, end), - _crop_optional(self.rejected_spans, begin, end), - _crop_optional(self.image_patches, begin, end), - ) - - def to_device_(self, device: "torch.device | str"): - self.tokens.to_device_(device) - if self.loss_masking_spans is not None: - self.loss_masking_spans.to_device_(device) - if self.chosen_spans is not None: - self.chosen_spans.to_device_(device) - if self.rejected_spans is not None: - self.rejected_spans.to_device_(device) - if self.image_patches is not None: - self.image_patches.to_device_(device) - - -def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: - return None if any(arg is None for arg in args) else fn(args) - - -def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: - return None if sample_or_batch is None else sample_or_batch.crop(begin, end) - - -@config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) -class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): - _abstract = False - header: typing.ClassVar[bytes] = b"lm begin" - footer: typing.ClassVar[bytes] = b"lm end" - tokens: TokenReaderConfig = Field() - # Using dynamic type for optional readers for enabling/disabling - loss_masking_spans: MemmapReaderBaseConfig = Field() - chosen_spans: MemmapReaderBaseConfig = Field() - rejected_spans: MemmapReaderBaseConfig = Field() - image_patches: MemmapReaderBaseConfig = Field() - - def _validate(self) -> None: - super()._validate() - if isinstance(self.preprocessing, NullPreprocessingConfig): - # Address missing config, mostly for backward compatibility. - # TODO: We can't tell which dataset this comes from. - logger.warning( - f"Preprocessing configuration not specified for dataset reader, generating partial configuration from known parameters." - ) - if isinstance(self.image_patches, PatchReaderConfig): - Assert.eq(len(patch_shape := self.image_patches.patch_shape), 3) - image_patches = ImagePatchConfig(height=patch_shape[1], width=patch_shape[2]) - else: - image_patches = NullPreprocessingConfig() - self.preprocessing = LanguageModelPreprocessingConfig( - image_patches=image_patches, - use_loss_masking_spans=isinstance(self.loss_masking_spans, RangeReaderConfig), - use_preference_spans=isinstance(self.chosen_spans, RangeReaderConfig), - ) - # TODO: Avoid duplicated information. 
- Assert.custom( - isinstance, - self.loss_masking_spans, - RangeReaderConfig if self.preprocessing.use_loss_masking_spans else NullReaderConfig, - ) - Assert.custom( - isinstance, - self.chosen_spans, - RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, - ) - Assert.custom( - isinstance, - self.rejected_spans, - RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, - ) - if self.preprocessing.use_image_patches: - Assert.custom(isinstance, self.image_patches, PatchReaderConfig) - Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape) - Assert.eq(self.image_patches.data_type, DataType.uint8) - else: - Assert.custom(isinstance, self.image_patches, NullReaderConfig) - - def __len__(self) -> int: - return len(self.tokens) - - @property - def num_tokens(self) -> int: - return self.tokens.num_tokens - - @property - def reader_class(self) -> "type[LanguageModelReader]": - return LanguageModelReader - - @property - def writer_class(self) -> "type[LanguageModelWriter]": - return LanguageModelWriter - - @property - def _expected_buffer_size(self) -> int: - return ( - self.tokens.expected_buffer_size - + self.loss_masking_spans.expected_buffer_size - + self.chosen_spans.expected_buffer_size - + self.rejected_spans.expected_buffer_size - + self.image_patches.expected_buffer_size - ) - - def get_metadata(self) -> dict[str, typing.Any]: - out = super().get_metadata() - out["tokens"] = self.tokens.get_metadata() - if not isinstance(self.loss_masking_spans, NullReaderConfig): - out["loss_masking_spans"] = self.loss_masking_spans.get_metadata() - if not isinstance(self.chosen_spans, NullReaderConfig): - out["chosen_spans"] = self.chosen_spans.get_metadata() - if not isinstance(self.rejected_spans, NullReaderConfig): - out["rejected_spans"] = self.rejected_spans.get_metadata() - if not isinstance(self.image_patches, NullReaderConfig): - out["image_patches"] = self.image_patches.get_metadata() - return out - - @classmethod - def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: - out = super().blend_metadata(metadata) - out["tokens"] = TokenReaderConfig.blend_metadata([metadata_["tokens"] for metadata_ in metadata]) - if "loss_masking_spans" in metadata[0]: - out["loss_masking_spans"] = RangeReaderConfig.blend_metadata( - [metadata_["loss_masking_spans"] for metadata_ in metadata] - ) - if "chosen_spans" in metadata[0]: - out["chosen_spans"] = RangeReaderConfig.blend_metadata( - [metadata_["chosen_spans"] for metadata_ in metadata] - ) - if "rejected_spans" in metadata[0]: - out["image_patches"] = RangeReaderConfig.blend_metadata( - [metadata_["image_patches"] for metadata_ in metadata] - ) - if "image_patches" in metadata[0]: - out["image_patches"] = PatchReaderConfig.blend_metadata( - [metadata_["image_patches"] for metadata_ in metadata] - ) - return out - - -class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): - _model_preprocessing: LanguageModelPreprocessingConfig - - def __init__( - self, - config: ConfigType, - buffer: memoryview, - model_preprocessing: LanguageModelPreprocessingConfig | None = None, - ): - super().__init__(config, buffer, model_preprocessing) - self._config.preprocessing.check_compatibility(self._model_preprocessing) - # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. 
- self._tokens = self._config.tokens.get_reader(buffer) - - if self._model_preprocessing.use_loss_masking_spans: - if isinstance(self._config.loss_masking_spans, NullReaderConfig): - # TODO: We can't tell which dataset this comes from. - warnings.warn( - f"The model uses loss masking spans, but the dataset does not specify any." - " Assuming empty span lists." - ) - # TODO: this might have the same issue as empty PatchReaderConfig, so RangeReaderConfig.create_empty might be needed - self._loss_masking_spans = EmptyRangeReader(RangeReaderBaseConfig()) - else: - self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) - - if self._model_preprocessing.use_preference_spans: - self._chosen_spans = self._config.chosen_spans.get_reader(buffer) - self._rejected_spans = self._config.rejected_spans.get_reader(buffer) - - if self._model_preprocessing.use_image_patches: - model_image_preprocessing: ImagePatchConfig = self._model_preprocessing.image_patches - if isinstance(self._config.image_patches, NullReaderConfig): - warnings.warn( - f"The model uses image patches, but the dataset does not specify any." - " Assuming empty patch lists." - ) - self._image_patches = EmptyPatchReader( - PatchReaderBaseConfig(patch_shape=model_image_preprocessing.patch_shape, data_type=DataType.uint8), - ) - else: - self._image_patches = self._config.image_patches.get_reader(buffer) - - # TODO: Make this configurable. (Add to `model_preprocessing`?) - self._image_normalization_config = ImageNormalizationConfig() - - @property - def num_tokens(self) -> int: - return self._config.tokens.num_tokens - - def get_document(self, index: int, begin: int, end: int) -> Sample: - if self._model_preprocessing.use_image_patches: - image_patches = self._image_patches.get_document(index, begin, end) - image_patches.patches = self._image_normalization_config.normalize(image_patches.patches) - else: - image_patches = None - return LanguageModelSample( - self._tokens.get_document(index, begin, end), - ( - self._loss_masking_spans.get_document(index, begin, end) - if self._model_preprocessing.use_loss_masking_spans - else None - ), - ( - self._chosen_spans.get_document(index, begin, end) - if self._model_preprocessing.use_preference_spans - else None - ), - ( - self._rejected_spans.get_document(index, begin, end) - if self._model_preprocessing.use_preference_spans - else None - ), - image_patches, - ) - - def get_document_sizes(self) -> torch.Tensor: - return self._tokens.get_document_sizes() - - def get_document_size(self, index: int) -> int: - return self._tokens.get_document_size(index) - - def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: - begin_index, end_index, token_metadata = self._tokens.get_split(begin_ratio, end_ratio) - metadata = { - "num_tokens": token_metadata["num_tokens"], - "tokens": token_metadata, - } - if hasattr(self, "_loss_masking_spans") and isinstance(self._loss_masking_spans, RangeReader): - metadata["loss_masking_spans"] = self._loss_masking_spans.get_split(begin_index, end_index) - if hasattr(self, "_chosen_spans") and isinstance(self._chosen_spans, RangeReader): - metadata["chosen_spans"] = self._chosen_spans.get_split(begin_index, end_index) - if hasattr(self, "_rejected_spans") and isinstance(self._rejected_spans, RangeReader): - metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index) - if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader): - metadata["image_patches"] = 
self._image_patches.get_split(begin_index, end_index) - - return begin_index, end_index, metadata - - -class LanguageModelWriter(MemmapWriter): - _preprocessing_config: LanguageModelPreprocessingConfig - - def __enter__(self): - super().__enter__() - self._size_cumsum = [0] - self._data_type = None - - self._directory = tempfile.TemporaryDirectory() - self._path = pathlib.Path(self._directory.name) - # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times. - self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__() - if self._preprocessing_config.use_loss_masking_spans: - self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() - if self._preprocessing_config.use_preference_spans: - self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() - self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() - if self._preprocessing_config.use_image_patches: - self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() - return self - - def write(self, document: LanguageModelSample): - super().write(document) - # Write tokens. - self._token_writer.write(document.tokens) - - # Write loss masking spans. - if self._preprocessing_config.use_loss_masking_spans: - assert document.loss_masking_spans is not None - self._loss_masking_span_writer.write(document.loss_masking_spans) - - # Write preference spans. - if self._preprocessing_config.use_preference_spans: - assert document.chosen_spans is not None - assert document.rejected_spans is not None - self._chosen_spans_writer.write(document.chosen_spans) - self._rejected_spans_writer.write(document.rejected_spans) - - # Write image patches - if self._preprocessing_config.use_image_patches: - assert document.image_patches is not None - self._image_patches_writer.write(document.image_patches) - - def __exit__(self, exc_type, exc_val, exc_tb): - self._token_writer.__exit__(exc_type, exc_val, exc_tb) - if self._preprocessing_config.use_loss_masking_spans: - self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) - if self._preprocessing_config.use_preference_spans: - self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) - self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) - if self._preprocessing_config.use_image_patches: - self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) - - if exc_type is None: - # A dummy config so we can verify the begin and end offsets. 
- config = self._get_config(self._begin, None) - _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) - - if self._preprocessing_config.use_loss_masking_spans: - _copy_chunked( - self._path.joinpath("loss_masking_spans"), - self._stream, - config.loss_masking_spans.begin, - config.loss_masking_spans.end, - ) - if self._preprocessing_config.use_preference_spans: - _copy_chunked( - self._path.joinpath("chosen_spans"), - self._stream, - config.chosen_spans.begin, - config.chosen_spans.end, - ) - _copy_chunked( - self._path.joinpath("rejected_spans"), - self._stream, - config.rejected_spans.begin, - config.rejected_spans.end, - ) - - if self._preprocessing_config.use_image_patches: - _copy_chunked( - self._path.joinpath("image_patches"), - self._stream, - config.image_patches.begin, - config.image_patches.end, - ) - - self._directory.cleanup() - super().__exit__(exc_type, exc_val, exc_tb) - - @classmethod - def _get_config_class(cls) -> type[LanguageModelReaderConfig]: - return LanguageModelReaderConfig - - def _get_config(self, begin: int, end: int | None): - tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) - offset = tokens.end - if self._preprocessing_config.use_loss_masking_spans: - loss_masking_spans = self._loss_masking_span_writer.get_config(offset) - offset = loss_masking_spans.end - else: - loss_masking_spans = NullReaderConfig() - if self._preprocessing_config.use_preference_spans: - chosen_spans = self._chosen_spans_writer.get_config(offset) - offset = chosen_spans.end - rejected_spans = self._rejected_spans_writer.get_config(offset) - offset = rejected_spans.end - else: - chosen_spans = NullReaderConfig() - rejected_spans = NullReaderConfig() - if self._preprocessing_config.use_image_patches: - image_patches = self._image_patches_writer.get_config(offset) - offset = image_patches.end - else: - image_patches = NullReaderConfig() - - if end is None: - end = offset + len(LanguageModelReaderConfig.footer) - - return LanguageModelReaderConfig( - begin=begin, - end=end, - tokens=tokens, - loss_masking_spans=loss_masking_spans, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, - image_patches=image_patches, - preprocessing=self._preprocessing_config, - ) - - -def _copy_chunked(path: pathlib.Path, stream: io.BufferedWriter, expected_begin: int, expected_end: int): - # Copy temporary file content in chunks of 100 MB. 
- Assert.eq(stream.tell(), expected_begin) - with path.open("rb") as input_stream: - while data := input_stream.read(100000000): - stream.write(data) - Assert.eq(stream.tell(), expected_end) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py deleted file mode 100644 index 0be91f0c8..000000000 --- a/fast_llm/data/sample/patch.py +++ /dev/null @@ -1,359 +0,0 @@ -import math -import typing - -import numpy as np -import torch - -from fast_llm.config import Field, config_class -from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import ( - Batch, - MemmapReader, - MemmapReaderBase, - MemmapReaderBaseConfig, - MemmapReaderConfig, - MemmapWriter, - Sample, -) -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import Assert, get_unique, padded_cumsum - - -def filter_lengths(lengths: list[int], filter: torch.Tensor) -> list[int]: - length_cumsum = padded_cumsum(lengths) - filtered_lengths = (filter[begin:end].sum().item() for begin, end in zip(length_cumsum[:-1], length_cumsum[1:])) - return [length for length in filtered_lengths if length > 0] - - -class PatchSample(Sample): - """ - A reusable component holding a set of fixed-shape patches (ex. images, audio, video), - each of which providing a single token embedding in a multimodal model. - """ - - def __init__( - self, - patches: torch.Tensor, - token_map: torch.Tensor, - positions: torch.Tensor, - sample_size: int, - lengths: list[int] | None = None, - ): - # Tensor of dimensions (patch, *patch_shape) - self.patches = patches - # Mapping from patch to token index - self.token_map = token_map - # A position identifier for each patch in the patch grid. - Assert.eq(positions.shape, (self.patches.size(0), self.patches.ndim - 2)) - self.positions = positions - # Number of tokens in the sample (not the number of patches) - self.sample_size = sample_size - # Length of each patch group (ex. image) in the sample. TODO: Use cumsums instead? 
- if lengths is None: - lengths = [len(patches)] - else: - Assert.eq(sum(lengths), len(patches)) - self.lengths = lengths - - @classmethod - def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: - total_size = 0 - embedding_maps = [] - for document in documents: - embedding_maps.append(document.token_map + total_size) - total_size += document.sample_size - return cls( - torch.cat([document.patches for document in documents]), - torch.cat(embedding_maps), - torch.cat([document.positions for document in documents]), - total_size, - sum((document.lengths for document in documents), []), - ) - - def crop(self, begin: int, end: int) -> typing.Self: - sample_size = end - begin - patch_filter = (self.token_map >= begin) & (self.token_map < end) - return self.__class__( - self.patches[patch_filter], - self.token_map[patch_filter] - begin, - self.positions[patch_filter], - sample_size, - filter_lengths(self.lengths, patch_filter), - ) - - def __len__(self) -> int: - return self.sample_size - - def get_padding(self, size: int) -> typing.Self: - return self.__class__( - self.patches.new_empty((0, *self.patches.shape[1:])), - self.token_map.new_empty(0), - self.positions.new_empty([0, self.patches.ndim - 2]), - size, - [], - ) - - def to_device_(self, device: "torch.device | str"): - self.patches = self.patches.to(device, non_blocking=True) - self.token_map = self.token_map.to(device, non_blocking=True) - self.positions = self.positions.to(device, non_blocking=True) - - -class PatchBatch(Batch): - def __init__( - self, - patches: torch.Tensor, - sample_map: torch.Tensor, - token_map: torch.Tensor, - positions: torch.Tensor, - num_samples: int, - sample_size: int, - lengths: list[int], - ): - # Concatenated along patch index rather than stacked since the lengths are not constant - self.patches = patches - # Mapping from patch to sample index - self.sample_map = sample_map - self.token_map = token_map - self.positions = positions - self.num_samples = num_samples - self.sample_size = sample_size - self.lengths = lengths - - @classmethod - def from_samples(cls, samples: typing.Sequence[PatchSample]) -> typing.Self: - return cls( - torch.cat([sample.patches for sample in samples]), - torch.cat( - [torch.full_like(sample.token_map, sample_index) for sample_index, sample in enumerate(samples)] - ), - torch.cat([sample.token_map for sample in samples]), - torch.cat([sample.positions for sample in samples]), - len(samples), - get_unique(sample.sample_size for sample in samples), - [length for sample in samples for length in sample.lengths], - ) - - def crop(self, begin: int, end: int) -> typing.Self: - sample_size = end - begin - patch_filter = (self.token_map >= begin) & (self.token_map < end) - - return self.__class__( - self.patches[patch_filter], - self.sample_map[patch_filter], - self.token_map[patch_filter], - self.positions[patch_filter], - self.num_samples, - sample_size, - filter_lengths(self.lengths, patch_filter), - ) - - def to_device_(self, device: "torch.device | str"): - self.patches = self.patches.to(device, non_blocking=True) - self.sample_map = self.sample_map.to(device, non_blocking=True) - self.token_map = self.token_map.to(device, non_blocking=True) - self.positions = self.positions.to(device, non_blocking=True) - - -@config_class() -class PatchReaderBaseConfig(MemmapReaderBaseConfig): - _abstract = False - patch_shape: tuple[int, ...] 
= Field() - data_type: DataType = Field() - - @property - def patch_size(self) -> int: - return math.prod(self.patch_shape) - - @property - def grid_dims(self) -> int: - return len(self.patch_shape) - 1 - - -@config_class(dynamic_type={MemmapReaderBaseConfig: "patch"}) -class PatchReaderConfig(PatchReaderBaseConfig, MemmapReaderConfig): - header: typing.ClassVar[bytes] = b"patch begin" - footer: typing.ClassVar[bytes] = b"patch end" - num_documents: int = Field() - num_patches: int = Field() - num_patch_groups: int = Field() - - def __len__(self) -> int: - return self.num_documents - - @property - def reader_class(self) -> "type[PatchReader]": - return PatchReader - - @property - def writer_class(self) -> "type[PatchWriter]": - return PatchWriter - - @property - def _expected_buffer_size(self) -> int: - return ( - self.num_patches * self.patch_size * self.data_type.torch.itemsize - + ((1 + self.grid_dims) * self.num_patches + self.num_patch_groups + 2 * self.num_documents + 2) - * torch.int32.itemsize - ) - - def get_metadata(self) -> dict[str, typing.Any]: - return { - "num_documents": self.num_documents, - "num_patches": self.num_patches, - "num_patch_groups": self.num_patch_groups, - "num_pixels": self.patch_size * self.num_patches, - "patch_shape": self.patch_shape, - "data_type": str(self.data_type), - } - - @classmethod - def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: - return { - "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), - "num_patches": sum(metadata_["num_patches"] for metadata_ in metadata), - "num_patch_groups": sum(metadata_["num_patch_groups"] for metadata_ in metadata), - "num_pixels": sum(metadata_["num_pixels"] for metadata_ in metadata), - "patch_shape": get_unique(metadata_["patch_shape"] for metadata_ in metadata), - "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), - } - - -class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): - super().__init__(config, buffer, model_preprocessing) - self._patches = torch.frombuffer( - self._buffer, - dtype=self._config.data_type.torch, - count=self._config.num_patches * self._config.patch_size, - ).view(self._config.num_patches, *self._config.patch_shape) - offset = self._patches.nbytes - self._token_map = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_patches, - offset=offset, - ) - offset += self._token_map.nbytes - self._positions = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_patches * self._config.grid_dims, - offset=offset, - ).view(self._config.num_patches, self._config.grid_dims) - offset += self._positions.nbytes - self._patch_count_cumsums = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_documents + 1, - offset=offset, - ) - offset += self._patch_count_cumsums.nbytes - self._group_lengths = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_patch_groups, - offset=offset, - ) - offset += self._group_lengths.nbytes - self._group_count_cumsums = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_documents + 1, - offset=offset, - ) - - def get_document(self, index: int, begin: int, end: int) -> Sample: - token_map = self._token_map[ - token_slice := slice(self._patch_count_cumsums[index], self._patch_count_cumsums[index + 1]) - ] - 
patch_filter = (token_map >= begin) & (token_map < end) - return PatchSample( - self._patches[token_slice][patch_filter], - token_map[patch_filter] - begin, - self._positions[token_slice][patch_filter], - end - begin, - filter_lengths( - self._group_lengths[self._group_count_cumsums[index] : self._group_count_cumsums[index + 1]].tolist(), - patch_filter, - ), - ) - - def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: - Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) - num_patches = self._patch_count_cumsums[end_index].item() - self._patch_count_cumsums[begin_index].item() - return { - "num_documents": end_index - begin_index, - "num_patches": num_patches, - "num_patch_groups": self._group_count_cumsums[end_index].item() - - self._group_count_cumsums[begin_index].item(), - "num_pixels": self._config.patch_size * num_patches, - "patch_shape": self._config.patch_shape, - "data_type": str(self._config.data_type), - } - - -class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]): - def get_document(self, index: int, begin: int, end: int) -> Sample: - return PatchSample( - torch.empty(0, *self._config.patch_shape, dtype=self._config.data_type.torch), - torch.empty(0, dtype=torch.int32), - torch.empty(0, self._config.grid_dims, dtype=torch.int32), - end - begin, - ) - - -class PatchWriter(MemmapWriter): - def __enter__(self): - super().__enter__() - self._patch_count_cumsum = [0] - self._group_count_cumsum = [0] - self._token_map = [] - self._positions = [] - self._group_lengths = [] - self._data_type = None - self._patch_shape = None - return self - - def write(self, document: PatchSample): - super().write(document) - if self._data_type is None: - self._data_type = document.patches.dtype - else: - Assert.eq(self._data_type, document.patches.dtype) - if self._patch_shape is None: - self._patch_shape = tuple(document.patches.shape[1:]) - else: - Assert.eq(self._patch_shape, document.patches.shape[1:]) - self._stream.write(document.patches.numpy().tobytes()) - self._token_map.extend(document.token_map) - self._positions.extend(document.positions) - self._patch_count_cumsum.append(self._patch_count_cumsum[-1] + len(document.patches)) - self._group_count_cumsum.append(self._group_count_cumsum[-1] + len(document.lengths)) - self._group_lengths.extend(document.lengths) - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - Assert.lt(self._patch_count_cumsum[-1], np.iinfo(np.int32).max) - self._stream.write(np.array(self._token_map, dtype=np.int32).tobytes(order="C")) - self._stream.write(np.array(self._positions, dtype=np.int32).tobytes(order="C")) - self._stream.write(np.array(self._patch_count_cumsum, dtype=np.int32).tobytes(order="C")) - self._stream.write(np.array(self._group_lengths, dtype=np.int32).tobytes(order="C")) - self._stream.write(np.array(self._group_count_cumsum, dtype=np.int32).tobytes(order="C")) - super().__exit__(exc_type, exc_val, exc_tb) - - @classmethod - def _get_config_class(cls) -> type[PatchReaderConfig]: - return PatchReaderConfig - - def _get_config(self, begin: int, end: int): - return PatchReaderConfig( - begin=begin, - end=end, - num_documents=len(self._patch_count_cumsum) - 1, - num_patches=self._patch_count_cumsum[-1], - num_patch_groups=self._group_count_cumsum[-1], - patch_shape=self._patch_shape, - data_type=DataType.from_torch(self._data_type), - preprocessing=self._preprocessing_config, - ) diff --git a/fast_llm/data/sample/range.py 
b/fast_llm/data/sample/range.py deleted file mode 100644 index f57ee04d9..000000000 --- a/fast_llm/data/sample/range.py +++ /dev/null @@ -1,173 +0,0 @@ -import typing - -import numpy as np -import torch - -from fast_llm.config import Field, config_class -from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import ( - Batch, - MemmapReader, - MemmapReaderBase, - MemmapReaderBaseConfig, - MemmapReaderConfig, - MemmapWriter, - Sample, -) -from fast_llm.utils import Assert, get_unique - - -def crop_ranges(ranges: list[tuple[int, int]], begin: int, end: int) -> list[tuple[int, int]]: - cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in ranges) - return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] - - -class RangeSample(Sample): - """ - A reusable component holding a set of ranges in a sample. - """ - - def __init__(self, ranges: list[tuple[int, int]], sample_size: int): - self.ranges = ranges - self.sample_size = sample_size - - @classmethod - def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: - """ - Used to merge ranges from multiple documents, i.e. when multiple docuemnts are packed together. - """ - document: RangeSample - ranges = [] - sample_size = 0 - for document in documents: - for begin, end in document.ranges: - ranges.append((begin + sample_size, end + sample_size)) - sample_size += document.sample_size - return cls(ranges, sample_size) - - def crop(self, begin: int, end: int) -> typing.Self: - return self.__class__(crop_ranges(self.ranges, begin, end), end - begin) - - def __len__(self) -> int: - return self.sample_size - - def get_padding(self, size: int) -> typing.Self: - return self.__class__([], size) - - -class RangeBatch(Batch): - def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int): - self.sample_size = sample_size - self.ranges = ranges - - @classmethod - def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: - return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples)) - - def crop(self, begin: int, end: int) -> typing.Self: - return self.__class__([crop_ranges(sample_ranges, begin, end) for sample_ranges in self.ranges], end - begin) - - -@config_class() -class RangeReaderBaseConfig(MemmapReaderBaseConfig): - _abstract = False - - -@config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) -class RangeReaderConfig(RangeReaderBaseConfig, MemmapReaderConfig): - header: typing.ClassVar[bytes] = b"range begin" - footer: typing.ClassVar[bytes] = b"range end" - num_documents: int = Field() - num_ranges: int = Field() - - @property - def reader_class(self) -> "type[RangeReader]": - return RangeReader - - @property - def writer_class(self) -> "type[RangeWriter]": - return RangeWriter - - @property - def _expected_buffer_size(self) -> int: - return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize - - def get_metadata(self) -> dict[str, typing.Any]: - return { - "num_documents": self.num_documents, - "num_ranges": self.num_ranges, - } - - @classmethod - def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: - return { - "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), - "num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata), - } - - -class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): - def __init__(self, 
config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): - super().__init__(config, buffer, model_preprocessing) - self._ranges = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_ranges * 2, - ).view(-1, 2) - self._count_cumsums = torch.frombuffer( - self._buffer, - dtype=torch.int32, - count=self._config.num_documents + 1, - offset=self._ranges.nbytes, - ) - - def get_document(self, index: int, begin: int, end: int) -> Sample: - sample_size = end - begin - cropped_ranges = ( - (max(begin_ - begin, 0), min(end_ - begin, sample_size)) - for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist() - ) - return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) - - def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: - Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) - return { - "num_documents": end_index - begin_index, - "num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(), - } - - -class EmptyRangeReader[ConfigType: RangeReaderBaseConfig](MemmapReaderBase[ConfigType]): - def get_document(self, index: int, begin: int, end: int) -> Sample: - return RangeSample([], end - begin) - - -class RangeWriter(MemmapWriter): - def __enter__(self): - super().__enter__() - self._count_cumsum = [0] - return self - - def write(self, document: RangeSample): - super().write(document) - self._stream.write(np.array(document.ranges, dtype=np.int32).tobytes(order="C")) - self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - Assert.lt(self._count_cumsum[-1], np.iinfo(np.int32).max) - self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) - super().__exit__(exc_type, exc_val, exc_tb) - - @classmethod - def _get_config_class(cls) -> type[RangeReaderConfig]: - return RangeReaderConfig - - def _get_config(self, begin: int, end: int): - return RangeReaderConfig( - begin=begin, - end=end, - num_documents=len(self._count_cumsum) - 1, - num_ranges=self._count_cumsum[-1], - preprocessing=self._preprocessing_config, - ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py deleted file mode 100644 index 17078cef9..000000000 --- a/fast_llm/data/sample/token.py +++ /dev/null @@ -1,265 +0,0 @@ -import typing - -import numpy as np -import torch - -from fast_llm.config import Field, config_class -from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import ( - Batch, - MemmapIndexedDatasetReader, - MemmapReaderBaseConfig, - MemmapReaderConfig, - MemmapWriter, - Sample, -) -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import Assert, get_unique, padded_cumsum - - -def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]: - if len(lengths) == 1: - # Shortcut for the frequent case of a single document. 
- return [end - begin] - begin_ = 0 - lengths_ = [] - for length in lengths: - end_ = begin_ + length - cropped_length = min(end_, end) - max(begin_, begin) - if cropped_length > 0: - lengths_.append(cropped_length) - if end_ > end: - break - begin_ = end_ - return lengths_ - - -class TokenSample(Sample): - def __init__( - self, - tokens: torch.Tensor, - lengths: list[int] | None = None, - sequence_k_past: int = 0, - current_document_begin: int = 0, - ): - self.tokens = tokens - # Length of each document in the sample. TODO: Use cumsums instead? - if lengths is None: - lengths = [len(tokens)] - else: - Assert.eq(sum(lengths), len(tokens)) - self.lengths = lengths - self.sequence_k_past = sequence_k_past - self.current_document_begin = current_document_begin - - @classmethod - def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: - return cls( - torch.cat([document.tokens for document in documents]), - sum((document.lengths for document in documents), []), - ) - - def crop(self, begin: int, end: int) -> typing.Self: - Assert.eq(self.sequence_k_past, self.current_document_begin, 0) - - document_begin = 0 - lengths_ = [] - current_document_begin = None - for length in self.lengths: - document_end = document_begin + length - cropped_length = min(document_end, end) - max(document_begin, begin) - if cropped_length > 0: - lengths_.append(cropped_length) - if not current_document_begin: - current_document_begin = document_begin - if document_end > end: - break - document_begin = document_end - - return self.__class__(self.tokens[begin:end], lengths_, begin, current_document_begin) - - def __len__(self) -> int: - return len(self.tokens) - - def get_padding(self, size: int) -> typing.Self: - return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size]) - - def to_device_(self, device: "torch.device | str"): - # Also standardize the dtype while we're here. 
- self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) - - def get_cumulative_lengths(self, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]: - cumulative_lengths_q = torch.from_numpy(padded_cumsum(self.lengths)).to(dtype=torch.int32, device=device) - cumulative_lengths_k = torch.cat( - [self.current_document_begin, cumulative_lengths_q[1:] + self.sequence_k_past] - ) - return cumulative_lengths_q, cumulative_lengths_k - - def get_max_lengths(self, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]: - max_length_q = max(self.lengths) - max_length_k = max(self.max_length_q, self.sequence_k_past + self.lengths[0] - self.current_document_begin) - return ( - torch.full((1,), max_length_q, dtype=torch.int32, device=device), - torch.full((1,), max_length_k, dtype=torch.int32, device=device), - ) - - def get_document_index(self, device: torch.device | None = None) -> torch.Tensor: - return torch.cat( - [ - torch.full((document_length,), i, dtype=torch.int32, device=device) - for i, document_length in enumerate(self.lengths) - ] - ) - - def get_position_index(self, device: torch.device | None = None) -> torch.Tensor: - return torch.cat( - [torch.arange(document_length, dtype=torch.int32, device=device) for document_length in self.lengths] - ) - - -class TokenBatch(Batch): - def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None: - self.tokens = tokens - if lengths is None: - lengths = [[tokens.size(1)]] * tokens.size(0) - self.lengths = lengths - - @classmethod - def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: - return cls( - torch.stack([sample.tokens for sample in samples]), - [sample.lengths for sample in samples], - ) - - def crop(self, begin: int, end: int) -> typing.Self: - return self.__class__( - self.tokens[:, begin:end], - [crop_lengths(lengths, begin, end) for lengths in self.lengths], - ) - - def to_device_(self, device: "torch.device | str"): - # Also standardize the dtype while we're here. 
- self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) - - -@config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) -class TokenReaderConfig(MemmapReaderConfig): - _abstract = False - header: typing.ClassVar[bytes] = b"token begin" - footer: typing.ClassVar[bytes] = b"token end" - num_documents: int = Field() - num_tokens: int = Field() - data_type: DataType = Field() - - def __len__(self) -> int: - return self.num_documents - - @property - def reader_class(self) -> "type[TokenReader]": - return TokenReader - - @property - def writer_class(self) -> "type[TokenWriter]": - return TokenWriter - - @property - def _expected_buffer_size(self) -> int: - return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.int64.itemsize - - def get_metadata(self) -> dict[str, typing.Any]: - return { - "num_tokens": self.num_tokens, - "num_documents": self.num_documents, - "data_type": str(self.data_type), - } - - @classmethod - def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: - return { - "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata), - "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), - "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), - } - - -class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): - super().__init__(config, buffer, model_preprocessing) - self._tokens = torch.frombuffer( - self._buffer, - dtype=self._config.data_type.torch, - count=self._config.num_tokens, - ) - self._size_cumsums = torch.frombuffer( - self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._tokens.nbytes - ) - - def get_document(self, index: int, begin: int, end: int) -> Sample: - begin_ = self._size_cumsums[index].item() - # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. 
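For context on the reader defined above: the memmap layout is a flat token buffer followed by an int64 cumulative-size array with `num_documents + 1` entries, so document `i` spans `tokens[cumsum[i]:cumsum[i + 1]]`. A hedged standalone sketch of recovering one document from such a buffer (names and signature are illustrative, not the Fast-LLM API):

import torch

def read_document(buffer: memoryview, num_tokens: int, num_documents: int, token_dtype: torch.dtype, index: int) -> torch.Tensor:
    # Token buffer first, then the cumulative document sizes appended at the end.
    tokens = torch.frombuffer(buffer, dtype=token_dtype, count=num_tokens)
    cumsum = torch.frombuffer(buffer, dtype=torch.int64, count=num_documents + 1, offset=tokens.nbytes)
    # Convert to int64 so downstream code gets a uniform dtype regardless of storage type.
    return tokens[cumsum[index] : cumsum[index + 1]].to(torch.int64)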
- # Convert begin and end to int to avoid numpy dtype overflow when adding to begin_ - return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin]) - - def get_document_sizes(self) -> torch.Tensor: - return self._size_cumsums[1:] - self._size_cumsums[:-1] - - def get_document_size(self, index: int) -> int: - return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item() - - def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: - Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1]) - begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens) - end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens) - - return ( - begin_index, - end_index, - { - "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(), - "num_documents": end_index - begin_index, - "data_type": str(self._config.data_type), - }, - ) - - -def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: - left = torch.searchsorted(cumsum, value, side="right") - if left == len(cumsum): - return left.item() - return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() - - -class TokenWriter(MemmapWriter): - def __enter__(self): - super().__enter__() - self._size_cumsum = [0] - self._data_type = None - return self - - def write(self, document: TokenSample): - super().write(document) - if self._data_type is None: - self._data_type = document.tokens.dtype - else: - Assert.eq(self._data_type, document.tokens.dtype) - self._stream.write(document.tokens.numpy().tobytes()) - self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) - super().__exit__(exc_type, exc_val, exc_tb) - - @classmethod - def _get_config_class(cls) -> type[TokenReaderConfig]: - return TokenReaderConfig - - def _get_config(self, begin: int, end: int): - return TokenReaderConfig( - begin=begin, - end=end, - num_documents=len(self._size_cumsum) - 1, - num_tokens=self._size_cumsum[-1], - data_type=DataType.from_torch(self._data_type), - preprocessing=self._preprocessing_config, - ) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 387610a46..fd02a6dc3 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -6,8 +6,6 @@ import torch import transformers.modeling_outputs -from fast_llm.data.sample.language_model import LanguageModelBatch -from fast_llm.data.sample.token import TokenBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index e32b78ff9..2f96f6f91 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,7 +5,7 @@ import torch -from fast_llm.batch.language_model import LanguageModelBatchNew +from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -42,7 +42,7 @@ def __init__( def 
preprocess_batch( self, - batches: list[LanguageModelBatchNew], + batch: LanguageModelPreprocessedBatch, *, phase: PhaseType, iteration: int, @@ -55,17 +55,17 @@ def preprocess_batch( reference_preprocessed_batches = {} for name, reference_model in self._reference_models.items(): reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch( - batches, + batch, phase=PhaseType.inference, iteration=iteration, ) preprocessed = [] presents = None - for micro_sequence_index, batch in enumerate(batches): + for micro_sequence_index, micro_sequence in enumerate(batch.micro_batches): pasts = presents - presents = None if micro_sequence_index == len(batches) - 1 else [] - batch.to_device_(self._distributed.device) + presents = None if micro_sequence_index == len(batch) - 1 else [] + micro_sequence.to_device_(self._distributed.device) kwargs: dict[str, typing.Any] = { LanguageModelKwargs.phase: phase, AttentionKwargs.past_key_values: pasts, @@ -74,22 +74,22 @@ def preprocess_batch( LanguageModelKwargs.device: self._distributed.device, LanguageModelKwargs.output_hidden_states: [], LanguageModelKwargs.hidden_states: {}, - LanguageModelKwargs.token_dim: batch.token_dim, - LanguageModelKwargs.hidden_token_dim: batch.hidden_token_dim, - LanguageModelKwargs.sequence_k_dim: batch.sequence_k_dim, - LanguageModelKwargs.num_tokens: batch.num_tokens, - LanguageModelKwargs.sequence_length: batch.sequence_length, - LanguageModelKwargs.sequence_lengths: batch.document_lengths, - LanguageModelKwargs.labels: batch.labels, - LanguageModelKwargs.loss_mask: batch.prediction_masks, - AttentionKwargs.cu_seqlens_q: batch.cumulative_lengths_q, - AttentionKwargs.cu_seqlens_k: batch.cumulative_lengths_k, - AttentionKwargs.max_seqlen_q: batch.max_length_q, - AttentionKwargs.max_seqlen_k: batch.max_length_k, - LanguageModelKwargs.seq_idx: batch.document_index, - LanguageModelKwargs.position_ids: batch.position_index, - LanguageModelKwargs.chosen_spans: batch.chosen_spans, - LanguageModelKwargs.rejected_spans: batch.rejected_spans, + LanguageModelKwargs.token_dim: micro_sequence.token_dim, + LanguageModelKwargs.hidden_token_dim: micro_sequence.hidden_token_dim, + LanguageModelKwargs.sequence_k_dim: micro_sequence.sequence_k_dim, + LanguageModelKwargs.num_tokens: micro_sequence.num_tokens, + LanguageModelKwargs.sequence_length: micro_sequence.sequence_length, + LanguageModelKwargs.sequence_lengths: micro_sequence.document_lengths, + LanguageModelKwargs.labels: micro_sequence.labels, + LanguageModelKwargs.loss_mask: micro_sequence.prediction_masks, + AttentionKwargs.cu_seqlens_q: micro_sequence.cumulative_lengths_q, + AttentionKwargs.cu_seqlens_k: micro_sequence.cumulative_lengths_k, + AttentionKwargs.max_seqlen_q: micro_sequence.max_length_q, + AttentionKwargs.max_seqlen_k: micro_sequence.max_length_k, + LanguageModelKwargs.seq_idx: micro_sequence.document_index, + LanguageModelKwargs.position_ids: micro_sequence.position_index, + LanguageModelKwargs.chosen_spans: micro_sequence.chosen_spans, + LanguageModelKwargs.rejected_spans: micro_sequence.rejected_spans, } if extra_kwargs is not None: Assert.empty(kwargs.keys() & extra_kwargs.keys()) @@ -112,7 +112,7 @@ def preprocess_batch( for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() } self.preprocess(kwargs) - preprocessed.append((batch.tokens, kwargs)) + preprocessed.append((micro_sequence.tokens, kwargs)) return preprocessed diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 
df7f78643..e65556501 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -1,7 +1,7 @@ import logging import typing -from fast_llm.batch.config import LanguageModelBatchPreprocessingConfig +from fast_llm.data.batch.language_model import LanguageModelBatchPreprocessingConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.config import SamplingParameters from fast_llm.engine.distributed.config import PhaseType diff --git a/fast_llm/models/multimodal/huggingface.py b/fast_llm/models/multimodal/huggingface.py index 8b0859992..12491937f 100644 --- a/fast_llm/models/multimodal/huggingface.py +++ b/fast_llm/models/multimodal/huggingface.py @@ -5,7 +5,6 @@ import transformers.modeling_outputs from fast_llm.data.preprocessing.image_patch import ImagePatchConfig -from fast_llm.data.sample.patch import PatchBatch from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM from fast_llm.models.multimodal.config import MultiModalModelConfig diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 87d8f3310..7eb784148 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -4,7 +4,6 @@ import torch from fast_llm.core.distributed import all_gather_scalar -from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner diff --git a/tests/data/common.py b/tests/data/common.py index 7ec4a9018..26aeda845 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -5,6 +5,7 @@ import torch from fast_llm.config import NoAutoValidate +from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset @@ -12,6 +13,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset +from fast_llm.data.document.language_model import LanguageModelBatch from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -86,17 +88,27 @@ def get_test_data_and_compare_samples( assert "sampling" not in config config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) - data = GPTData(GPTDataConfig.from_dict(config), distributed_config) - data.setup(distributed, sampling_parameters, preprocessing, cache_directory) with NoAutoValidate(): batch_config = GPTBatchConfig(batch_size=1, sequence_length=sequence_length) batch_config.setup(distributed_config) batch_config.validate() + preprocessing = LanguageModelBatchPreprocessingConfig.from_dict( + preprocessing, {"batch": batch_config, "type": None} + ) + data = GPTData(GPTDataConfig.from_dict(config), distributed_config) + data.setup( + distributed, + sampling_parameters, + {dataset_name: preprocessing for dataset_name in samples_per_dataset}, + cache_directory, + ) tokens = { phase: torch.stack( [ - batch.tokens.tokens[0] - for batch in 
data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0) + batch.tokens.tokens + for batch in data.get_iterator( + batch_config, phase, consumed_samples=0, num_workers=0, preprocess=False + ) ] ) for phase, samples in samples_per_dataset.items() @@ -128,7 +140,12 @@ def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list # for i in range(len(expected_samples)): # print(i, sampled[i].tokens.tokens.tolist()) Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal(torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]), expected_samples) + Assert.all_equal( + torch.stack( + [LanguageModelBatch.from_documents(sampled[i]).tokens.tokens for i in range(len(expected_samples))] + ), + expected_samples, + ) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -163,7 +180,9 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] for index in range(sampled._parameters.num_samples) ] - token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) + token_ids = torch.stack( + [LanguageModelBatch.from_documents(sampled[i]).tokens.tokens for i in range(len(sampled))] + ).to(torch.int64) Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 989e99b24..b49a44b2a 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,7 +4,7 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -114,7 +114,7 @@ def test_gpt_blended(): "datasets": [config, alt_config], "weights": [0.75, 0.25], }, - BlendedDatasetConfig[LanguageModelSample], + BlendedDatasetConfig[LanguageModelDocument], ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) @@ -142,7 +142,7 @@ def test_gpt_blended_mixed(): ], "weights": [0.6, 0.4], }, - BlendedDatasetConfig[LanguageModelSample], + BlendedDatasetConfig[LanguageModelDocument], ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 19539cc8c..cf75ea413 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,7 +1,7 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig +from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset_tokens, compare_sampled_dataset, @@ -30,7 +30,7 @@ def test_gpt_concatenate(): memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() dataset = get_dataset_config( dataset_config := {"type": 
"concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, - ConcatenatedDatasetConfig[LanguageModelSample], + ConcatenatedDatasetConfig[LanguageModelDocument], ).build(LanguageModelPreprocessingConfig()) compare_indexed_dataset_tokens( dataset, diff --git a/tests/data/test_image_patch.py b/tests/data/test_image_patch.py index 747f6a737..8d5d7301c 100644 --- a/tests/data/test_image_patch.py +++ b/tests/data/test_image_patch.py @@ -8,9 +8,8 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.dataset.memmap.memmap import MemmapDataset +from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.utils import Assert from tests.data.common import get_dataset_config from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_SAMPLES, COMMON_DATASET_TEXT @@ -126,7 +125,7 @@ def _get_image_tokens( @pytest.mark.parametrize("image_end_token", (None, 132)) def test_gpt_data_with_image_patches(image_break_token, image_end_token): _, config, hf_path, preprocessing = get_test_dataset_with_image_patches(image_break_token, image_end_token) - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + dataset: MemmapDataset[LanguageModelDocument] = get_dataset_config(config, GPTDatasetFromFileConfig).build( preprocessing ) test_index = 2 * (image_break_token is not None) + (image_end_token is not None) @@ -174,11 +173,9 @@ def test_gpt_data_with_image_patches(image_break_token, image_end_token): def test_gpt_data_with_missing_image_patches(): path, config, hf_path, _ = get_common_test_dataset() _, _, _, preprocessing = get_test_dataset_with_image_patches(config_only=True) - LanguageModelPreprocessingConfig - with pytest.warns(match="The model uses image patches"): - dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build(preprocessing) + dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build(preprocessing) for index in COMMON_DATASET_SAMPLES: document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), COMMON_DATASET_SAMPLES[index]) - Assert.eq(document.image_patches.patches.shape, (0,) + preprocessing.image_patches.patch_shape) + Assert.none(document.image_patches) diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index 30047163a..a963170fd 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -3,9 +3,9 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.memmap.memmap import MemmapDataset +from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preprocessing.tokenizer import TokenizerConfig -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.data.common import get_dataset_config from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_SAMPLES, COMMON_DATASET_TEXT @@ -39,7 +39,7 @@ @pytest.mark.slow def 
test_gpt_data_with_loss_masking_spans(): _, config, hf_path, preprocessing = get_test_dataset_with_loss_masking_spans() - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + dataset: MemmapDataset[LanguageModelDocument] = get_dataset_config(config, GPTDatasetFromFileConfig).build( preprocessing ) @@ -83,10 +83,9 @@ def test_gpt_data_with_loss_masking_spans(): def test_gpt_data_with_missing_loss_masking_spans(): path, config, hf_path, _ = get_common_test_dataset() _, _, _, preprocessing = get_test_dataset_with_loss_masking_spans(config_only=True) - with pytest.warns(match="The model uses loss masking spans"): - dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build(preprocessing) + dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build(preprocessing) for index in COMMON_DATASET_SAMPLES: document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), COMMON_DATASET_SAMPLES[index]) - Assert.eq(document.loss_masking_spans.ranges, []) + Assert.none(document.loss_masking_spans) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py index 7ba4e04ac..ef12e3837 100644 --- a/tests/data/test_preference_spans.py +++ b/tests/data/test_preference_spans.py @@ -5,9 +5,9 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.memmap.memmap import MemmapDataset +from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preprocessing.tokenizer import TokenizerConfig -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.data.common import get_dataset_config from tests.data.test_preparator import COMMON_DATASET_LENGTH @@ -41,7 +41,7 @@ @pytest.mark.slow def test_gpt_data_with_spans(): _, config, hf_path, preprocessing = get_test_dataset_with_preference_spans() - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + dataset: MemmapDataset[LanguageModelDocument] = get_dataset_config(config, GPTDatasetFromFileConfig).build( preprocessing ) diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index f4f6fab82..ab5942c20 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -3,9 +3,10 @@ import datasets import pytest -from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters +from fast_llm.data.dataset.config import BlendedDatasetConfig, SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.memmap.config import MemmapDatasetConfig +from fast_llm.data.dataset.memmap.memmap import MemmapDataset from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index f28c9cce2..737609994 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -5,8 +5,8 @@ from fast_llm.data.dataset.config import SamplingParameters, ShufflingType from 
fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.sample.token import TokenSample +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.document.token import TokenDocument from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -40,7 +40,7 @@ def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() sampled = get_dataset_config( - dataset_config := config, GPTDatasetFromFileConfig[LanguageModelSample] + dataset_config := config, GPTDatasetFromFileConfig[LanguageModelDocument] ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) @@ -54,17 +54,19 @@ def test_gpt_sampled(): ) -class SimpleGPTIndexedDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): +class SimpleGPTIndexedDataset[DocumentType: LanguageModelDocument](IndexedDataset[DocumentType]): # TODO: worth adding to the main codebase? def __init__(self, samples): self._samples = samples def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: + ) -> DocumentType: if end is None: end = len(self._samples[index]) - return LanguageModelSample(TokenSample(torch.tensor(self._samples[index][begin:end], dtype=torch.int64))) + return LanguageModelDocument( + tokens=TokenDocument(tokens=torch.tensor(self._samples[index][begin:end], dtype=torch.int64)) + ) def __len__(self) -> int: return len(self._samples) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 54263b8e2..ddf16acf1 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,6 +1,6 @@ from fast_llm.data.dataset.config import DatasetSliceConfig from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.document.language_model import LanguageModelDocument from tests.data.common import ( compare_indexed_dataset_tokens, get_dataset_config, @@ -36,7 +36,7 @@ def test_gpt_slice(): # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": memmap_config, "begin": 0.025, "end": 0.1}, - DatasetSliceConfig[LanguageModelSample], + DatasetSliceConfig[LanguageModelDocument], ).build(preprocessing) compare_indexed_dataset_tokens(dataset, 75, 3399, {i - 25: sample for i, sample in COMMON_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index e29050b28..011bb5aea 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -11,14 +11,15 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import MemmapDatasetConfig, SampledDatasetConfig +from fast_llm.data.dataset.config import SampledDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset +from fast_llm.data.dataset.memmap.config import MemmapDatasetConfig from 
fast_llm.data.dataset.sampled import logger +from fast_llm.data.document.language_model import LanguageModelDocument +from fast_llm.data.document.token import TokenDocument from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig -from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig @@ -48,8 +49,8 @@ def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() samples = [ - LanguageModelSample( - TokenSample((tokenizer.tokenize(document["text"]) % MODEL_TEST_VOCAB_SIZE).to(torch.uint16)) + LanguageModelDocument( + TokenDocument((tokenizer.tokenize(document["text"]) % MODEL_TEST_VOCAB_SIZE).to(torch.uint16)) ) for document in hf_dataset ] @@ -116,14 +117,14 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co @config_class(dynamic_type={SampledDatasetConfig: "megatron"}) -class MegatronDatasetConfig[SampleType: LanguageModelSample](MemmapDatasetConfig[SampleType]): +class MegatronDatasetConfig[DocumentType: LanguageModelDocument](MemmapDatasetConfig[DocumentType]): _abstract: typing.ClassVar[bool] = False path: str = Field( desc="Dataset path (prefix).", hint=FieldHint.core, ) - def build(self, preprocessing: PreprocessingConfig) -> "LegacyMemmapDataset[SampleType]": + def build(self, preprocessing: PreprocessingConfig) -> "LegacyMemmapDataset[DocumentType]": return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path, preprocessing) @@ -135,7 +136,7 @@ def sample(self, sampling: GPTSamplingData) -> "MegatronSampledIndexedDataset": def write_dataset( cls, prefix: pathlib.Path | str, - documents: typing.Iterable[LanguageModelSample], + documents: typing.Iterable[LanguageModelDocument], ) -> None: # Initialize metadata dtype = None @@ -192,7 +193,7 @@ def write_dataset( idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C")) -class MegatronSampledIndexedDataset(SampledDataset): +class MegatronSampledIndexedDataset[DocumentType: LanguageModelDocument](SampledDataset[DocumentType]): """ A GPT sampled dataset that exactly matches Megatron-LM, for testing purposes. Minimalistic implementation, implements only the required features. 
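For readers unfamiliar with the Megatron sampling layout: `_sample_idx[i]` is a (document, offset) pair marking where sample `i` begins, so a fixed-length sample is stitched from a span of consecutive documents, as in the `__getitem__` rewrite in the next hunk. A simplified sketch with assumed names:

def stitch_sample(indexed_dataset, doc_idx, sample_idx, idx):
    doc_first, offset_first = sample_idx[idx]
    doc_last, offset_last = sample_idx[idx + 1]
    # The first document is read from its offset, the last up to and including its offset,
    # and every document in between entirely.
    return [
        indexed_dataset.get_document(
            int(doc_idx[doc]),
            begin=offset_first if doc == doc_first else 0,
            end=offset_last + 1 if doc == doc_last else None,
        )
        for doc in range(doc_first, doc_last + 1)
    ]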
@@ -231,20 +232,18 @@ def __init__( def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx: int) -> typing.Any: + def __getitem__(self, idx: int) -> list[DocumentType]: shuffled_idx = self._shuffle_idx[idx] doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - return LanguageModelSample.from_documents( - [ - self._indexed_dataset.get_document( - self._doc_idx[doc].item(), - begin=(doc == doc_f) * offset_f, - end=offset_l + 1 if doc == doc_l else None, - ) - for doc in range(doc_f, doc_l + 1) - ] - ) + return [ + self._indexed_dataset.get_document( + self._doc_idx[doc].item(), + begin=(doc == doc_f) * offset_f, + end=offset_l + 1 if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] @property def name(self) -> str: diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index cdf2295e0..f0af94256 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -9,9 +9,6 @@ import torch from fast_llm.config import NoAutoValidate -from fast_llm.data.sample.language_model import LanguageModelBatch -from fast_llm.data.sample.range import RangeBatch -from fast_llm.data.sample.token import TokenBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTBatchConfig, GPTModelConfig From 1697a482fc23e7a71bc03f02d505bd8711136409 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 19 Feb 2026 19:28:10 -0500 Subject: [PATCH 3/4] stuff --- fast_llm/data/batch/config.py | 66 +++- fast_llm/data/batch/language_model.py | 33 +- fast_llm/data/data/abstract.py | 35 +- fast_llm/data/data/gpt/data.py | 99 +++--- fast_llm/data/dataset/config.py | 7 +- fast_llm/data/dataset/gpt/config.py | 4 +- fast_llm/data/dataset/sampled.py | 2 +- fast_llm/engine/base_model/base_model.py | 9 +- fast_llm/engine/config_utils/run.py | 6 +- fast_llm/engine/distributed/config.py | 9 +- fast_llm/engine/distributed/distributed.py | 3 +- fast_llm/engine/evaluation/config.py | 15 +- fast_llm/engine/evaluation/evaluator.py | 309 +++++------------ .../engine/evaluation/lm_eval/evaluator.py | 52 +-- .../evaluation/lm_eval/fast_llm_wrapper.py | 4 +- fast_llm/engine/multi_stage/fast_llm_model.py | 10 + fast_llm/engine/schedule/config.py | 14 +- fast_llm/engine/schedule/runner.py | 3 +- fast_llm/engine/schedule/schedule.py | 44 +-- fast_llm/engine/training/config.py | 18 +- fast_llm/engine/training/trainer.py | 316 ++++-------------- fast_llm/layers/language_model/head.py | 7 +- fast_llm/layers/language_model/loss/dpo.py | 4 + fast_llm/layers/language_model/loss/loss.py | 17 +- fast_llm/logging.py | 2 - fast_llm/models/gpt/config.py | 10 - fast_llm/models/gpt/model.py | 8 +- fast_llm/models/gpt/trainer.py | 28 -- fast_llm/models/multimodal/trainer.py | 3 +- tests/data/common.py | 21 +- 30 files changed, 388 insertions(+), 770 deletions(-) diff --git a/fast_llm/data/batch/config.py b/fast_llm/data/batch/config.py index a3d192bae..360a07fb6 100644 --- a/fast_llm/data/batch/config.py +++ b/fast_llm/data/batch/config.py @@ -1,10 +1,11 @@ +import abc import dataclasses import functools import logging import typing -from fast_llm.config import Field, config_class -from fast_llm.data.document.abstract import Document +from fast_llm.config import Configurable, Field, FieldUpdate, config_class +from fast_llm.data.document.abstract import Batch, Document from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, 
PreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImagePatchConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig @@ -22,15 +23,18 @@ @config_class() class BatchPreprocessingConfig(PreprocessingConfig): - pass + batch: BatchConfig = Field() + phase: PhaseType = Field(default=PhaseType.inference) + + def get_batch_meta(self) -> "PreprocessedBatch": + raise NotImplementedError() @config_class() -class LanguageModelBatchPreprocessingConfig(LanguageModelPreprocessingConfig): +class LanguageModelBatchPreprocessingConfig(LanguageModelPreprocessingConfig, BatchPreprocessingConfig): _abstract = False # TODO: Duplicate `use_loss_masking_spans`, `use_preference_spans` - batch: GPTBatchConfig = Field() - phase: PhaseType = Field(default=PhaseType.inference) + batch: GPTBatchConfig = FieldUpdate() predicted_tokens: int = Field(default=1) return_cumulative_sequence_lengths: bool = Field(default=False) return_max_sequence_lengths: bool = Field(default=False) @@ -43,10 +47,28 @@ def _validate(self) -> None: Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig)) Assert.custom(isinstance, self.tokenizer, (TokenizerConfig, NullPreprocessingConfig)) + def get_batch_meta(self) -> "PreprocessedBatch": + from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch + from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument + from fast_llm.data.document.token import TokenDocument + + device = torch.device("meta") + tokens = torch.empty(self.total_length, dtype=torch.int64, device=device) + batch = LanguageModelBatch.from_documents([LanguageModelDocument(tokens=TokenDocument(tokens=tokens))]) + return LanguageModelPreprocessedBatch.from_batch(batch, config=self, device=device) + @functools.cached_property def use_image_patches(self) -> bool: return isinstance(self.image_patches, ImagePatchConfig) + @functools.cached_property + def total_length(self) -> int: + return self.batch.sequence_length + self.predicted_tokens + + @functools.cached_property + def distributed(self) -> DistributedConfig: + return self.batch.distributed + def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? 
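The new `get_batch_meta` relies on PyTorch's `meta` device to build a batch with the right shapes and dtypes but no storage, so schedule metadata can be derived without allocating memory. A small standalone illustration in plain torch (not the Fast-LLM classes):

import torch

sequence_length, predicted_tokens = 4096, 1
tokens = torch.empty(sequence_length + predicted_tokens, dtype=torch.int64, device="meta")
print(tokens.shape, tokens.dtype, tokens.device)  # torch.Size([4097]) torch.int64 meta
# Shape arithmetic works as usual, but reading values does not: tokens[0].item() raises,
# so meta batches are only suitable for shape/metadata propagation.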
@@ -64,21 +86,37 @@ class MicroBatch: pass -@dataclasses.dataclass -class PreprocessedBatch: - micro_batches: list[MicroBatch] +class PreprocessedBatch[ConfigType: BatchPreprocessingConfig, MicroBatchType: MicroBatch](Configurable[ConfigType]): + def __init__(self, config: ConfigType, micro_batches: list[MicroBatchType]): + super().__init__(config) + self._micro_batches = micro_batches + @property + def micro_batches(self) -> list[MicroBatch]: + return self._micro_batches -@config_class(registry=True) -class BatchPreprocessingConfig(PreprocessingConfig): - batch: BatchConfig = Field() + def __len__(self) -> int: + return len(self._micro_batches) + + def __getitem__(self, idx: int) -> MicroBatchType: + return self._micro_batches[idx] @classmethod + @abc.abstractmethod def from_documents( cls, - config: BatchPreprocessingConfig, - distributed_config: DistributedConfig, documents: list[Document], + config: BatchPreprocessingConfig, + device: "torch.device | None" = None, + ) -> typing.Self: + pass + + @classmethod + @abc.abstractmethod + def from_batch( + cls, + batch: Batch, + config: BatchPreprocessingConfig, device: "torch.device | None" = None, ) -> typing.Self: pass diff --git a/fast_llm/data/batch/language_model.py b/fast_llm/data/batch/language_model.py index b0f67fc1c..06bc90e37 100644 --- a/fast_llm/data/batch/language_model.py +++ b/fast_llm/data/batch/language_model.py @@ -6,7 +6,7 @@ from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig, MicroBatch, PreprocessedBatch from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedDimNames @dataclasses.dataclass @@ -46,30 +46,27 @@ def to_device_(self, device: torch.device): @dataclasses.dataclass -class LanguageModelPreprocessedBatch(PreprocessedBatch): - micro_batches: list[LanguageModelMicroBatch] +class LanguageModelPreprocessedBatch[ + ConfigType: LanguageModelBatchPreprocessingConfig, MicroBatchType: LanguageModelMicroBatch +](PreprocessedBatch[ConfigType, MicroBatchType]): + def __init__(self, config: LanguageModelBatchPreprocessingConfig, micro_batches: list[MicroBatchType]): + super().__init__(config, micro_batches) @classmethod def from_documents( cls, documents: list[LanguageModelDocument], - *, - config: LanguageModelBatchPreprocessingConfig, - distributed_config: DistributedConfig, + config: ConfigType, device: torch.device | None = None, ) -> typing.Self: - batch = LanguageModelBatch.from_documents( - documents, pad_to_size=config.batch.sequence_length + config.predicted_tokens - ) - return cls.from_batch(batch, config=config, distributed_config=distributed_config, device=device) + batch = LanguageModelBatch.from_documents(documents, pad_to_size=config.total_length) + return cls.from_batch(batch, config=config, device=device) @classmethod def from_batch( cls, batch: LanguageModelBatch, - *, - config: LanguageModelBatchPreprocessingConfig, - distributed_config: DistributedConfig, + config: ConfigType, device: torch.device | None = None, ) -> typing.Self: if device is None: @@ -79,21 +76,21 @@ def from_batch( token_dim = TensorDim( "token", config.batch.micro_sequence_length, - distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), + config.distributed.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_token_dim = ( ( "token_tp", 
token_dim.global_size, - distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), + config.distributed.get_distributed_dim(DistributedDimNames.tensor_and_data), ) - if distributed_config.sequence_tensor_parallel + if config.distributed.sequence_tensor_parallel else token_dim ) micro_batches = [] for micro_sequence_index, sequence_k_past in enumerate( range( - token_dim.size * distributed_config.sequence_data_rank, + token_dim.size * config.distributed.sequence_data_rank, config.batch.sequence_length, token_dim.global_size, ) @@ -147,4 +144,4 @@ def from_batch( micro_batch.prediction_masks.append(labels > 0) micro_batches.append(micro_batch) - return LanguageModelPreprocessedBatch(micro_batches=micro_batches) + return cls(micro_batches=micro_batches, config=config) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index c5400b6c7..87b6ddd17 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -3,13 +3,10 @@ import typing from fast_llm.config import Configurable -from fast_llm.data.batch.config import PreprocessedBatch +from fast_llm.data.batch.config import BatchPreprocessingConfig, PreprocessedBatch from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.distributed.distributed import Distributed @@ -17,32 +14,28 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" - _sampling_parameters: dict[str, SamplingParameters] - _preprocessing: dict[str, PreprocessingConfig] + # _sampling_parameters: dict[str, SamplingParameters] + # _preprocessing: dict[str, PreprocessingConfig] _cache_directory: pathlib.Path | None + _is_setup: bool = False def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: super().__init__(config) self._distributed_config = distributed_config # TODO: Improve interface - def setup( - self, - distributed: "Distributed", - sampling_parameters: dict[str, SamplingParameters], - preprocessing: dict[str, PreprocessingConfig], - cache_directory: pathlib.Path, - timeout: float | None = None, - ) -> None: - Assert.eq(sampling_parameters.keys(), preprocessing.keys()) - self._distributed = distributed - self._sampling_parameters = sampling_parameters - self._preprocessing = preprocessing + def setup(self, cache_directory: pathlib.Path) -> None: self._cache_directory = cache_directory + self._is_setup = True - @property - def distributed(self): - return self._distributed + @abc.abstractmethod + def sample_dataset( + self, + dataset_name: str, + config: BatchPreprocessingConfig, + num_samples: int, + ) -> None: + pass @abc.abstractmethod def get_iterator( diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index ff1fbd3bc..e15d95e90 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,13 +1,11 @@ import functools import logging -import pathlib import typing import warnings import torch import torch.utils.data -from fast_llm.core.distributed import safe_barrier from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch from fast_llm.data.data.abstract import 
Data @@ -20,7 +18,6 @@ from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert @@ -33,9 +30,8 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ _datasets: dict[str, SampledDataset] - _sampling_parameters: dict[str, SamplingParameters] + # _sampling_parameters: dict[str, SamplingParameters] _preprocessing: dict[str, LanguageModelBatchPreprocessingConfig] - _is_setup: bool = False def __init__( self, @@ -47,56 +43,46 @@ def __init__( Should be `setup` before use. """ super().__init__(config, distributed_config) + self._datasets = {} + self._preprocessing = {} - def setup( + def sample_dataset( self, - distributed: "Distributed", - sampling_parameters: dict[str, SamplingParameters], - preprocessing: dict[str, LanguageModelBatchPreprocessingConfig], - cache_directory: pathlib.Path, - timeout: float | None = None, + dataset_name: str, + config: LanguageModelBatchPreprocessingConfig, + num_samples: int, ) -> None: - """ - Load the datasets, and prepare or load the samplings. - This may take a while and a significant amount of cpu memory. - """ - super().setup(distributed, sampling_parameters, preprocessing, cache_directory) - - # Check and raise an error if a used dataset is not defined. - for dataset_name in self._sampling_parameters.keys(): - if dataset_name not in self._config.datasets: - raise ValueError(f"Dataset {dataset_name} not found.") - - # Check and warn if there are defined datasets that are not used. - unused_datasets = self._config.datasets.keys() - self._sampling_parameters.keys() - if unused_datasets: - warnings.warn( - f"The following datasets are defined but not used: {', '.join(unused_datasets)}. " - "Ensure this is intentional, or update the configuration accordingly." - ) + assert self._is_setup + Assert.gt(num_samples, 0) + if dataset_name not in self._config.datasets: + raise ValueError(f"Dataset {dataset_name} not found.") + if dataset_name in self._datasets: + raise ValueError(f"Dataset {dataset_name} is already sampled.") - log_main_rank(f"Preparing dataset. This may take several minutes.") + log_main_rank(f"Sampling dataset {dataset_name}. 
This may take several minutes.") if self._cache_directory is None: # TODO: Avoid this - warnings.warn(f"Using the dataset directory for the index cache.") + warnings.warn(f"The index cache will be saved in the dataset directory.") - self._datasets = {} - for dataset_name, sampling_parameters in self._sampling_parameters.items(): - if sampling_parameters.num_samples > 0: - sampling = GPTSamplingData( - config=self._config.sampling, - parameters=sampling_parameters, - preprocessing=self._preprocessing[dataset_name], - cache_directory=self._cache_directory, - distributed=distributed, - dataset_name=dataset_name, - ) - dataset = self._config.datasets[dataset_name].build_and_sample(sampling) - self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) - - safe_barrier(self._distributed.world_group, "data_preparation", timeout) - self._is_setup = True + sampling_parameters = SamplingParameters( + sequence_length=config.batch.sequence_length, + num_samples=num_samples, + truncate_documents=config.batch.truncate_documents, + extra_tokens=config.predicted_tokens, + ) + + sampling = GPTSamplingData( + config=self._config.sampling, + parameters=sampling_parameters, + preprocessing=config, + cache_directory=self._cache_directory, + distributed_config=self._distributed_config, + dataset_name=dataset_name, + ) + self._preprocessing[dataset_name] = config + dataset = self._config.datasets[dataset_name].build_and_sample(sampling) + self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) def get_iterator( self, @@ -116,8 +102,6 @@ def get_iterator( dataset_name = dataset_name.lower() Assert.incl(dataset_name, self._datasets) - sampling_parameters = self._sampling_parameters[dataset_name] - Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") return iter( @@ -126,9 +110,9 @@ def get_iterator( batch_sampler=SampledDatasetIterator( total_samples=len(self._datasets[dataset_name]), begin_index=consumed_samples, - micro_batch_size=batch_config.micro_batch_size, - data_rank=self._distributed.config.batch_data_rank, - data_parallel=self._distributed.config.batch_data_parallel, + micro_batch_size=self._preprocessing[dataset_name].batch.micro_batch_size, + data_rank=self._distributed_config.batch_data_rank, + data_parallel=self._distributed_config.batch_data_parallel, ), num_workers=num_workers, prefetch_factor=prefetch_factor, @@ -145,14 +129,7 @@ def _collate_fn( preprocess: bool = True, ) -> LanguageModelPreprocessedBatch | LanguageModelBatch: documents = [document for documents_ in documents for document in documents_] - config = self._preprocessing[dataset_name] if preprocess: - return LanguageModelPreprocessedBatch.from_documents( - documents, - config=config, - distributed_config=self._distributed_config, - ) + return LanguageModelPreprocessedBatch.from_documents(documents, self._preprocessing[dataset_name]) else: - return LanguageModelBatch.from_documents( - documents, pad_to_size=config.batch.sequence_length + config.predicted_tokens - ) + return LanguageModelBatch.from_documents(documents, self._preprocessing[dataset_name].total_length) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 1e1fece26..39844ac8b 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -11,11 +11,11 @@ from fast_llm.data.dataset.abstract import SamplableDataset, 
SampledDataset from fast_llm.data.document.abstract import Document from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -85,8 +85,7 @@ class SamplingData: config: SamplingConfig parameters: SamplingParameters cache_directory: pathlib.Path | None - # TODO: This prevents the sampling config from being pickled in multiprocessing. - distributed: "Distributed" + distributed_config: DistributedConfig dataset_name: str preprocessing: PreprocessingConfig # Using a mutable rather than an int so it's shared with all copies made with `update`. @@ -99,7 +98,7 @@ def update_config(self, update: SamplingConfig): def get_next_rank(self) -> int: # Counter that loops over ranks to try to distribute workloads evenly between ranks. - return next(self._rank_counter()) % self.distributed.config.world_size + return next(self._rank_counter()) % self.distributed_config.world_size @config_class() diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index b66bc5445..62da794ee 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -196,7 +196,7 @@ class GPTTestSlowDatasetConfig[DocumentType: LanguageModelDocument](SampledDatas ) def build_and_sample(self, sampling: SamplingData) -> SampledDataset[DocumentType]: - assert sampling.distributed.config.world_size > 1 - if sampling.distributed.config.rank == 0: + assert sampling.distributed_config.world_size > 1 + if sampling.distributed_config.rank == 0: time.sleep(self.sleep) return GPTRandomDatasetConfig[DocumentType]().build_and_sample(sampling) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index a3b7c05a5..2ae5c693e 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -106,7 +106,7 @@ def __init__( self._yaml_path = base_path.with_suffix(".yaml") # Sample or validate the dataset of a given rank. - if sampling.distributed.config.rank == sampling.get_next_rank(): + if sampling.distributed_config.rank == sampling.get_next_rank(): self._sample() # No barrier yet to allow running in parallel. # There needs to be one before calling `__getitem__`, normally handled through `Data`. diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index f5f8dc5e7..195a1508a 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -5,6 +5,7 @@ import torch.nn from fast_llm.config import Configurable +from fast_llm.data.batch.config import PreprocessedBatch from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -174,16 +175,10 @@ def __init__( # TODO: Add basic handling (preprocessor) in this class. 
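The `get_next_rank` counter in `SamplingData` above spreads dataset sampling work across ranks in round-robin fashion, and since the class now carries only the distributed config it stays picklable for dataloader workers. A simplified standalone sketch of the scheme (assumed names):

import itertools

_rank_counter = itertools.count()  # shared by all copies of the sampling data

def get_next_rank(world_size: int) -> int:
    # Successive datasets are sampled on ranks 0, 1, ..., world_size - 1, 0, 1, ...
    return next(_rank_counter) % world_size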
self._reference_models: dict[str, "InferenceRunner"] = {} - @abc.abstractmethod - def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: - # TODO Remove (Move batch splitting elsewhere) - pass - @abc.abstractmethod def preprocess_batch( self, - batch: typing.Any, - preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + batch: PreprocessedBatch, *, phase: PhaseType, iteration: int, diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index baa386337..ab6f27489 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -231,8 +231,12 @@ def __exit__(self, exc_type, exc_val: OSError, exc_tb): _run: Run | None = None +def run_exists() -> bool: + return _run is not None + + def get_run() -> Run: - assert _run is not None + assert run_exists() return _run diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index c7ab610b2..d0011fc76 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -41,10 +41,9 @@ class PhaseType(enum.StrEnum): - training = "Training" - validation = "Validation" - test = "Test" - inference = "Inference" + training = "training" + validation = "validation" + inference = "inference" @property def is_training(self) -> bool: @@ -277,7 +276,7 @@ class DistributedConfig(Config): valid_seed_shift: int = Field( default=_BIG_PRIMES[9], desc="Seed shift for extra randomness.", hint=FieldHint.optional ) - test_seed_shift: int = Field( + inference_seed_shift: int = Field( default=_BIG_PRIMES[10], desc="Seed shift for extra randomness.", hint=FieldHint.optional ) # (slower, uses more memory, mainly for debug) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 6ff9ce227..c13b40b60 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -223,8 +223,7 @@ def __init__(self, config: DistributedConfig): self._phase_seeds_shifts = { PhaseType.training: self._config.train_seed_shift, PhaseType.validation: self._config.valid_seed_shift, - PhaseType.test: self._config.test_seed_shift, - PhaseType.inference: self._config.test_seed_shift, + PhaseType.inference: self._config.inference_seed_shift, } self.set_step(0, PhaseType.training) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index df7ab0f51..f7ae62f04 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -17,8 +17,7 @@ def get_evaluator( self, name: str, batch_config: BatchConfig, - data_load_num_proc: int, - train_iters: int | None = None, + num_workers: int, ) -> "Evaluator": pass @@ -46,18 +45,15 @@ class LossEvaluatorConfig(EvaluatorConfig): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - dataset_name: str | None = Field(default=None) - def get_evaluator( self, name: str, batch_config: BatchConfig, - data_load_num_proc: int, - train_iters: int | None = None, + num_workers: int, ) -> "LossEvaluator": from fast_llm.engine.evaluation.evaluator import LossEvaluator - return LossEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + return LossEvaluator(name, self, batch_config, num_workers) @config_class(dynamic_type={EvaluatorConfig: "lm_eval"}) @@ -113,9 +109,8 @@ def get_evaluator( self, name: str, batch_config: BatchConfig, - data_load_num_proc: int, - train_iters: int | None = None, + num_workers: int, ) -> 
"EvaluatorLmEval": from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator - return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + return LmEvalEvaluator(name, self, batch_config, num_workers) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index e055595bd..8a3bd7e3d 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -1,24 +1,22 @@ import abc import dataclasses -import functools import logging -import math import time import typing from fast_llm.config import Configurable from fast_llm.core.distributed import safe_barrier +from fast_llm.data.batch.config import PreprocessedBatch from fast_llm.data.data.abstract import Data -from fast_llm.engine.config_utils.run import Run, log_main_rank +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.run import get_run, log_main_rank, run_exists from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, LossEvaluatorConfig +from fast_llm.engine.evaluation.config import EvaluatorConfig, LossEvaluatorConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.engine.training.config import WandbConfig -from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics from fast_llm.utils import get_and_reset_memory_usage_mib @@ -27,284 +25,141 @@ @dataclasses.dataclass class TrainingProgress: - done: bool completed_steps: int consumed_samples: int consumed_tokens: int -@dataclasses.dataclass -class EvaluationMetrics: - metrics: dict[str, any] = dataclasses.field(default_factory=dict) - formatted_metrics: str | None = None - - -@dataclasses.dataclass -class EvaluatorSamplingParameters: - dataset_name: str - num_samples: int - - class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): _is_setup: bool = False + _multi_stage: FastLLMModel + _runner: ScheduleRunner + _data: Data + _distributed: Distributed def __init__( self, name: str, eval_config: LossEvaluatorConfig, batch_config: BatchConfig, - data_load_num_proc: int, - train_iters: int | None = None, + num_workers: int, ): super().__init__(eval_config) self._name = name self._batch_config = batch_config - self._data_load_num_proc = data_load_num_proc - self._train_iters = train_iters + self._num_workers = num_workers + @abc.abstractmethod def setup( self, - distributed: Distributed, - run: Run, multi_stage: FastLLMModel, runner: ScheduleRunner, data: Data, - phase: PhaseType, + run_count: int, ) -> None: - # TODO: check if objects passed are actually set up themselves, if appropriate - self._distributed = distributed - self._run = run self._runner = runner self._multi_stage = multi_stage + self._distributed = multi_stage.distributed self._data = data - self._phase = phase + self._is_setup = True @abc.abstractmethod def run( self, - training_progress: TrainingProgress | None = None, - run_index: int | None = None, - ) -> EvaluationMetrics: ... 
- - @abc.abstractmethod - def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: - """ - Returns the name and number of required samples in a dataset, - or None if the evaluation does not rely on Fast-LLM data or - if the evaluation is skipped for this run. - """ + run_index: int | None, + metrics: dict[str, typing.Any], + ) -> None: + pass class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]): + _data_iterator: typing.Iterator[PreprocessedBatch] | None = None + _loss_definitions: list[LossDef] + _schedule: Schedule + _data: Data + def setup( self, - distributed: Distributed, - run: Run, multi_stage: FastLLMModel, runner: ScheduleRunner, data: Data, - phase: PhaseType, + run_count: int, ) -> None: - super().setup(distributed, run, multi_stage, runner, data, phase) + super().setup(multi_stage, runner, data, run_count) + preprocessing_config = self._multi_stage.get_preprocessing_config(PhaseType.validation) + self._data.sample_dataset( + self._name, preprocessing_config, run_count * self._config.iterations * self._batch_config.batch_size + ) # Setup the schedule self._schedule = Schedule( + config=runner.config, multi_stage=self._multi_stage, - batch_config=self._batch_config, - schedule_config=runner.config, - distributed_config=distributed.config, + batch_meta=preprocessing_config.get_batch_meta(), + distributed_config=self._distributed.config, phase=PhaseType.validation, ) - - self._loss_defs = self._multi_stage.base_model.get_loss_definitions() - self._evaluation_iterator = None - self._is_setup = True - - def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: - return ( - None - if self._config.iterations is None - else EvaluatorSamplingParameters( - (self._name if self._config.dataset_name is None else self._config.dataset_name), - self._config.iterations * self._batch_config.batch_size, - ) - ) + self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() + self._data_iterator = None def run( self, - training_progress: TrainingProgress | None = None, - run_index: int | None = None, - ) -> EvaluationMetrics: + run_index: int, + metrics: dict[str, typing.Any], + ) -> None: assert self._is_setup - if run_index is None: - run_index = 0 - - metrics = {} - - if self._evaluation_iterator is None: - self._evaluation_iterator = self._get_data_iterator(self._get_completed_evaluation_steps(run_index)) - # TODO: formatting metric category as Validation.evaluation_dataset_name - # maybe format each metric with evaluation_dataset_name prefix instead? - # TODO: setting performance metrics per evaluation dataset - # maybe to set aggregate performance metrics for all evaluations datasets? 
-        phase = PhaseType.validation
-        metric_key = f"{phase.value}.{self._name}"
-        metrics[metric_key] = self._evaluate_loss(
-            data_iterator=self._evaluation_iterator,
-            phase=phase,
-            num_iters=self._config.iterations,
-            begin_iter=self._get_completed_evaluation_steps(run_index),
-            completed_steps=None if training_progress is None else training_progress.completed_steps,
-        )
-
-        if self._train_iters is not None:
-            metrics[metric_key]["train_iters"] = self._train_iters
-
-        if training_progress is not None:
-            metrics[metric_key]["iteration"] = training_progress.completed_steps
-            metrics[metric_key]["consumed_samples"] = training_progress.consumed_samples
-            metrics[metric_key]["consumed_tokens"] = training_progress.consumed_tokens
-
-        formatted_metrics = format_metrics(
-            metrics[metric_key],
-            self._loss_defs,
-            phase,
-            dataset_name=self._name,
-        )
-
-        return EvaluationMetrics(metrics, formatted_metrics)
-
-    def _evaluate_loss(
-        self,
-        *,
-        data_iterator: typing.Iterator,
-        phase: PhaseType,
-        num_iters: int,
-        completed_steps: int | None,
-        begin_iter: int = 0,
-    ) -> dict[str, float | int]:
-        full_phase_name = f"{phase.value}_{self._name}"
-        safe_barrier(self._distributed.world_group, f"{full_phase_name} begin")
+        completed_evaluation_steps = max(0, run_index - 1) * self.config.iterations
+
+        if self._data_iterator is None:
+            # Keep the iterator so consecutive runs resume where the previous one stopped.
+            self._data_iterator = self._data.get_iterator(
+                self._batch_config,
+                self._name,
+                consumed_samples=completed_evaluation_steps * self._batch_config.batch_size,
+                num_workers=self._num_workers,
+            )
+        safe_barrier(self._distributed.world_group, f"{PhaseType.validation} {self._name} begin")
         begin_time = time.perf_counter()
-        total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs}
-        for iter_ in range(num_iters):
-            iter_losses, _, _ = self._runner.run_step(data_iterator, self._schedule, iteration=begin_iter + iter_)
+        total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions}
+        for iter_ in range(self._config.iterations):
+            iter_losses, _, _ = self._runner.run_step(
+                self._data_iterator, self._schedule, iteration=completed_evaluation_steps + iter_
+            )
             for name, value in iter_losses.items():
                 total_losses[name] += value
-            tensor_save_name = (
-                f"{full_phase_name}_{iter_}"
-                if completed_steps is None
-                else f"{full_phase_name}_{completed_steps}_{iter_}"
-            )
-            self._run.save_logged_tensors(tensor_save_name)
+            if run_exists():
+                get_run().save_logged_tensors(
+                    f"{PhaseType.validation}_{self._name}_{metrics.get('completed_steps', run_index)}_{iter_}"
+                )
 
         safe_barrier(
             self._distributed.world_group,
-            f"{full_phase_name} end",
+            f"{PhaseType.validation} {self._name} end",
         )
-        end_time = time.perf_counter()
-        time_per_iteration = (end_time - begin_time) / num_iters
-
-        model_compute, hardware_compute = self._schedule.compute_usage
-        model_tflops = math.nan if model_compute is None else model_compute / time_per_iteration
-        hardware_tflops = math.nan if hardware_compute is None else hardware_compute / time_per_iteration
-        # TODO add other relevant eval metrics
-        metrics = {
-            "batch_size": self._batch_config.batch_size,
-            **{name: (value / num_iters) for name, value in total_losses.items()},
-            "step_time_ms": time_per_iteration * 1000,
-            "model_tflops": model_tflops,
-            "hardware_tflops": hardware_tflops,
-            "tokens_per_sec_per_gpu": (
-                (self._batch_config.sequence_length * self._batch_config.batch_size)
-                / self._schedule._distributed_config.world_size
-                / time_per_iteration
-            ),
-            **get_and_reset_memory_usage_mib(),
-        }
-        return metrics
-
-    def _get_completed_evaluation_steps(self, run_index: int) -> int:
-        # Number of evaluations steps performed before the current step
-        return max(0, run_index - 1) * self.config.iterations
-
-    def _get_data_iterator(
-        self, completed_steps: int = 0, prefetch_factor: int | None = None
-    ) -> typing.Iterator[typing.Any]:
-        return self._data.get_iterator(
-            self._batch_config,
-            self._name,
-            consumed_samples=completed_steps * self._batch_config.batch_size,
-            num_workers=self._data_load_num_proc,
-            prefetch_factor=prefetch_factor,
+        time_per_iteration = (time.perf_counter() - begin_time) / self._config.iterations
+
+        metrics.update(
+            {
+                "batch_size": self._batch_config.batch_size,
+                **{name: (value / self._config.iterations) for name, value in total_losses.items()},
+                "step_time_ms": time_per_iteration * 1000,
+                **self._schedule.get_compute_metrics(time_per_iteration),
+                "tokens_per_sec_per_gpu": (
+                    (self._batch_config.sequence_length * self._batch_config.batch_size)
+                    / self._distributed.config.world_size
+                    / time_per_iteration
+                ),
+                **get_and_reset_memory_usage_mib(),
+            }
         )
 
-    @functools.cached_property
-    def compute_usage(self) -> tuple[int | None, int | None]:
-        return self._schedule.get_compute_usage(hardware=False), self._schedule.get_compute_usage(hardware=True)
-
-
-# NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation.
-class EvaluatorRunner:
-    _is_setup: bool = False
-
-    def __init__(
-        self,
-        evaluator_configs: dict[str, EvaluatorConfigBase],
-        batch_config: BatchConfig,
-        data_load_num_proc: int,
-        train_iters: int | None = None,
-        wandb_config: WandbConfig | None = None,
-    ):
-        self._wandb_config = wandb_config
-        self._evaluators = [
-            eval_config.get_evaluator(name, batch_config, data_load_num_proc, train_iters)
-            for name, eval_config in evaluator_configs.items()
-        ]
-
-    def setup(
-        self,
-        distributed: Distributed,
-        run: Run,
-        multi_stage: FastLLMModel,
-        runner: ScheduleRunner,
-        data: Data,
-        wandb: Wandb,
-        phase: PhaseType,
-    ) -> None:
-        self._wandb = wandb
-        for evaluator in self._evaluators:
-            evaluator.setup(distributed, run, multi_stage, runner, data, phase)
-        self._is_setup = True
-
-    def get_sampling_parameters(self) -> list[EvaluatorSamplingParameters]:
-        return [
-            sampling_params
-            for sampling_params in (evaluator.get_sampling_parameters() for evaluator in self._evaluators)
-            if sampling_params is not None
-        ]
-
-    def run(
-        self,
-        metrics: dict[str:any],
-        training_progress: TrainingProgress | None = None,
-    ):
-        assert self._is_setup
-        formatted_metrics = []
-        for evaluator in self._evaluators:
-            evaluation_metrics = evaluator.run(training_progress)
-            if len(evaluation_metrics.metrics) == 0:
-                continue
-            for k, v in evaluation_metrics.metrics.items():
-                metrics[k] = v
-            if evaluation_metrics.formatted_metrics is not None:
-                formatted_metrics.append(evaluation_metrics.formatted_metrics)
-
-        if len(formatted_metrics) > 0:
-            formatted_metrics = "\n".join(formatted_metrics)
-            log_main_rank(formatted_metrics)
-            if self._wandb_config is not None and self._wandb_config.alert.enabled(
-                0 if training_progress is None else training_progress.completed_steps
-            ):
-                self._wandb.alert("Validation results", formatted_metrics, "INFO")
+        log_main_rank(
+            format_metrics(
+                metrics,
+                self._loss_definitions,
+                PhaseType.validation,
+                dataset_name=self._name,
+            )
+        )
diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py
index 5bfb544ed..d03f87a24 100644
--- a/fast_llm/engine/evaluation/lm_eval/evaluator.py
+++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -4,38 +4,27 @@
 import typing
 
 from fast_llm.data.data.abstract import Data
-from fast_llm.engine.config_utils.run import Run
-from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.config_utils.run import get_run
 from fast_llm.engine.evaluation.config import LmEvalEvaluatorConfig
-from fast_llm.engine.evaluation.evaluator import (
-    EvaluationMetrics,
-    Evaluator,
-    EvaluatorSamplingParameters,
-    TrainingProgress,
-)
+from fast_llm.engine.evaluation.evaluator import Evaluator
+from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper
+from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel
 from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
 from fast_llm.engine.schedule.runner import ScheduleRunner
 
-if typing.TYPE_CHECKING:
-    from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper
-    from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
-
 logger = logging.getLogger(__name__)
 
 
 class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]):
-    _hf_model: "HuggingfaceBaseModelForCausalLM" = None
-    _flm_wrapper: "FastLLMLmEvalWrapper" = None
+    _hf_model: HuggingfacePreTrainedModel
+    _flm_wrapper: FastLLMLmEvalWrapper
 
     def setup(
         self,
-        distributed: Distributed,
-        run: Run,
         multi_stage: FastLLMModel,
         runner: ScheduleRunner,
         data: Data,
-        phase: PhaseType,
+        run_count: int,
     ) -> None:
         if "HUGGINGFACE_API_KEY_PATH" in os.environ:
             os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip()
@@ -48,18 +37,16 @@ def setup(
         from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper
 
-        super().setup(distributed, run, multi_stage, runner, data, phase)
+        super().setup(multi_stage, runner, data, run_count)
 
-        self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()(
-            self._multi_stage, runner=self._runner
-        )
+        hf_model = multi_stage.config_class.get_huggingface_model_for_causal_lm_class()(multi_stage, runner=runner)
 
         # For reporting purposes, just to indicate it is from Fast-LLM
         # as lm_eval.simple_evaluate will take it for results['config']['model']
-        self._hf_model.config.name_or_path = type(self._hf_model).__name__
+        hf_model.config.name_or_path = type(hf_model).__name__
 
         self._flm_wrapper = FastLLMLmEvalWrapper(
-            model=self._hf_model,
+            model=hf_model,
             tokenizer=self._config.tokenizer.get_tokenizer(),
             truncation=self._config.truncation,
             logits_cache=self._config.logits_cache,
@@ -73,18 +60,8 @@ def setup(
 
     def run(
         self,
-        training_progress: TrainingProgress | None = None,
-        run_index: int | None = None,
-    ) -> EvaluationMetrics:
+        run_index: int | None,
+        metrics: dict[str, typing.Any],
+    ) -> None:
         assert self._is_setup
-
-        # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/
-        completed_steps = 0 if training_progress is None else training_progress.completed_steps
-
-        self._flm_wrapper.run(self._config.cli_args, completed_steps, self._run.index)
-
-        # lm_eval logs to disc, wandb and prints to screen itself
-        return EvaluationMetrics()
-
-    def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None:
-        return None
+        self._flm_wrapper.run(self._config.cli_args, metrics.get("completed_steps", 0), get_run().index)
diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index bc42515e7..1b41f21c5 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -15,7 +15,7 @@ from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results -from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.attention.rotary.config import NoRotaryConfig @@ -28,7 +28,7 @@ class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): def __init__( self, - model: HuggingfaceBaseModelForCausalLM, + model: HuggingfacePreTrainedModel, tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, truncation: bool | None = False, logits_cache: bool = True, diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index ccde838e8..9ac6c5ccf 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,9 +1,12 @@ +import abc import logging import typing from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast +from fast_llm.data.batch.config import BatchPreprocessingConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig +from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import MultiStageModel @@ -77,6 +80,13 @@ def from_pretrained( model.initialize_weights() return model + @abc.abstractmethod + def get_preprocessing_config( + self, + phase: PhaseType, + ) -> BatchPreprocessingConfig: + pass + def initialize_weights(self, timeout: float | None = None) -> None: assert self._is_setup for stage in self._stages: diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 8696f0a59..1bffa0f0a 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -43,14 +43,14 @@ class BatchConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - _distributed: DistributedConfig = Field( + distributed: DistributedConfig = Field( init=False, desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) def setup(self, distributed_config: DistributedConfig) -> None: - self._distributed = distributed_config + self.distributed = distributed_config @functools.cached_property def num_inputs(self) -> int: @@ -73,19 +73,19 @@ def _validate(self) -> None: if self.micro_batch_size is None: self.micro_batch_size = 1 self.batch_size = ( - self.micro_batch_size * self.sequential_micro_batches * self._distributed.batch_data_parallel + self.micro_batch_size * self.sequential_micro_batches * self.distributed.batch_data_parallel ) elif self.micro_batch_size is None: self.micro_batch_size = div( - self.batch_size, self.sequential_micro_batches * self._distributed.batch_data_parallel + self.batch_size, self.sequential_micro_batches * self.distributed.batch_data_parallel ) else: 
self.sequential_micro_batches = div( - self.batch_size, self.micro_batch_size * self._distributed.batch_data_parallel + self.batch_size, self.micro_batch_size * self.distributed.batch_data_parallel ) if self.depth_first_micro_batches is None: if self.breadth_first_micro_batches is None: - if self._distributed.pipeline_parallel > 1: + if self.distributed.pipeline_parallel > 1: self.depth_first_micro_batches = 1 self.breadth_first_micro_batches = self.sequential_micro_batches else: @@ -102,7 +102,7 @@ def _validate(self) -> None: self.sequential_micro_batches, self.breadth_first_micro_batches * self.depth_first_micro_batches ) - if self._distributed.pipeline_parallel > 1 and self.depth_first_micro_batches > 1: + if self.distributed.pipeline_parallel > 1 and self.depth_first_micro_batches > 1: raise NotImplementedError("Depth-first pipeline parallelism not yet implemented") super()._validate() diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 4a6f3b3cb..92adfb1a9 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -149,7 +149,7 @@ def run_step( preprocessed: bool = False, ) -> tuple[dict[str, float | int], bool, dict[str, typing.Any] | None]: assert self._is_setup - assert schedule._schedule_config is self._config # Noqa + assert schedule._config is self._config # Noqa if schedule.phase.is_training: assert self._support_training @@ -335,7 +335,6 @@ def _preprocess_data( if not preprocessed: micro_batch_data = self._multi_stage.base_model.preprocess_batch( micro_batch_data, - context.schedule.preprocessed_meta, phase=context.phase, iteration=context.iteration, metrics=context.metrics, diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index fa25c914d..b0b72763e 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -1,7 +1,7 @@ -import abc import dataclasses import functools import logging +import math import typing import warnings @@ -10,11 +10,12 @@ import torch.utils import torch.utils.data +from fast_llm.config import Configurable +from fast_llm.data.batch.config import PreprocessedBatch from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig, StepType -from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -117,18 +118,19 @@ def get_stage_index(self, num_stages) -> int: return self.stage if self.type_ == StepType.forward else 2 * num_stages - 1 - self.stage -class Schedule(abc.ABC): +class Schedule[ConfigType: ScheduleConfig](Configurable[ConfigType]): def __init__( self, + config: ConfigType, + *, multi_stage: MultiStageModel, - batch_config: BatchConfig, - schedule_config: ScheduleConfig, + batch_meta: PreprocessedBatch, distributed_config: DistributedConfig, phase: PhaseType, ): + super().__init__(config) self._multi_stage = multi_stage - self._batch_config = batch_config - self._schedule_config = schedule_config + self._batch_config = batch_meta.config.batch self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase @@ -138,9 +140,10 @@ def __init__( warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. 
- self._preprocessed_meta = self._multi_stage.base_model.preprocess_meta( - self._batch_config, + self._preprocessed_meta = self._multi_stage.base_model.preprocess_batch( + batch_meta, phase=self._phase, + iteration=0, ) self._steps, self._first_grad_stage = self._create_steps() @@ -155,7 +158,7 @@ def __init__( self._setup_throttle_steps() self._setup_metas() - if self._schedule_config.debug_schedule: + if self._config.debug_schedule: logger.info(f"{self._phase.value} schedule:\n{self._steps}") @property @@ -166,10 +169,6 @@ def phase(self) -> PhaseType: def batch_config(self) -> BatchConfig: return self._batch_config - @property - def preprocessed_meta(self) -> list[tuple[TensorMeta, dict]]: - return self._preprocessed_meta - def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) @@ -281,7 +280,7 @@ def _setup_restore_steps(self, weight_buffer_indices: dict[int, int]) -> None: for step in device_steps: buffer_index = weight_buffer_indices[step.stage] if buffer_contents.get(buffer_index) != step.stage: - if self._schedule_config.data_overlap and self._distributed_config.use_cuda: + if self._config.data_overlap and self._distributed_config.use_cuda: step.restore_step = device_steps[buffer_last_used.get(buffer_index, -1) + 1] step.restore_event = torch.cuda.Event() else: @@ -378,7 +377,7 @@ def _setup_send_recv_steps(self) -> None: launch_step.recv_launch.append(recv_step) send_step.send_to = launch_step recv_step.recv_step = launch_step - if self._schedule_config.pipeline_overlap and self._distributed_config.use_cuda: + if self._config.pipeline_overlap and self._distributed_config.use_cuda: recv_step.recv_event = torch.cuda.Event() def _validate_send_recv_steps(self) -> None: @@ -449,12 +448,12 @@ def _validate_send_recv_steps(self) -> None: raise RuntimeError(f"Cannot find valid timeline for {self}, \nStatuses:{msg}") def _setup_throttle_steps(self) -> None: - if not self._schedule_config.throttle_cpu or not self._distributed_config.use_cuda: + if not self._config.throttle_cpu or not self._distributed_config.use_cuda: return for device_steps in self._device_steps: for i, step in enumerate(device_steps): - if i >= self._schedule_config.throttle_cpu_delay and i % self._schedule_config.throttle_cpu_rate == 0: - throttle_step = device_steps[i - self._schedule_config.throttle_cpu_delay] + if i >= self._config.throttle_cpu_delay and i % self._config.throttle_cpu_rate == 0: + throttle_step = device_steps[i - self._config.throttle_cpu_delay] throttle_step.throttle_event = torch.cuda.Event() step.throttle_step = throttle_step @@ -548,3 +547,10 @@ def get_compute_usage( @functools.cached_property def compute_usage(self) -> tuple[int | None, int | None]: return self.get_compute_usage(True, False), self.get_compute_usage(True, True) + + def get_compute_metrics(self, time_per_iteration: float) -> dict[str, float]: + model_compute, hardware_compute = self.compute_usage + return { + "model_tflops": math.nan if model_compute is None else model_compute / time_per_iteration, + "hardware_tflops": math.nan if hardware_compute is None else hardware_compute / time_per_iteration, + } diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 867cca984..9a1dfcc04 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -32,7 +32,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.training.trainer 
import Trainer, TrainingEvaluator + from fast_llm.engine.evaluation.evaluator import Evaluator + from fast_llm.engine.training.trainer import Trainer @config_class() @@ -163,12 +164,9 @@ def get_evaluator( self, name: str, batch_config: BatchConfig, - data_load_num_proc: int, - train_iters: int | None = None, - ) -> "TrainingEvaluator": - from fast_llm.engine.training.trainer import TrainingEvaluator - - return TrainingEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + num_workers: int, + ) -> "Evaluator": + return self.evaluator.get_evaluator(name, batch_config, num_workers) @config_class() @@ -288,12 +286,6 @@ class TrainingConfig(Config): train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) - test_iters: int = Field( - default=0, - desc="Number of iterations for the test phase at the end of training. Setting to 0 will disable the test phase.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) num_workers: int = Field( default=2, desc="Number of data loading processes for each data iterator.", diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 68c73bf70..0290a6468 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -11,30 +11,18 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import allreduce_scalar, safe_barrier from fast_llm.data.data.abstract import Data -from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.evaluation.evaluator import ( - EvaluationMetrics, - Evaluator, - EvaluatorRunner, - EvaluatorSamplingParameters, - TrainingProgress, -) from fast_llm.engine.multi_stage.config import StageMode -from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer -from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule from fast_llm.engine.training.config import ( TrainerConfig, TrainingCheckpointBaseConfig, TrainingCheckpointConfig, - TrainingEvaluatorConfig, ) from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, log_memory_usage @@ -43,99 +31,26 @@ logger = logging.getLogger(__name__) -class TrainingEvaluator[ConfigType: TrainingEvaluatorConfig](Evaluator[ConfigType]): - evaluator: Evaluator - - def __init__( - self, - name: str, - eval_config: TrainingEvaluatorConfig, - batch_config: BatchConfig, - data_load_num_proc: int, - train_iters: int | None = None, - ): - super().__init__(name, eval_config, batch_config, data_load_num_proc, train_iters) - - self._train_iters = 0 if self._train_iters is None else self._train_iters - - self.evaluator = eval_config.evaluator.get_evaluator(name, batch_config, data_load_num_proc, train_iters) - - def setup( - self, - distributed: Distributed, - run: Run, - multi_stage: FastLLMModel, - runner: ScheduleRunner, - data: Data, - phase: PhaseType, - ) -> None: - self.evaluator.setup( - distributed, - run, - 
multi_stage, - runner, - data, - phase, - ) - - def run( - self, - training_progress: TrainingProgress | None = None, - run_index: int | None = None, - ) -> EvaluationMetrics: - # Run index must be None because it is defined here to be passed to actual evaluator - assert run_index is None - - # Training progress can be None as it can be run in a training - # run without training, just evaluation - if training_progress is None: - done = True - completed_steps = 0 - else: - done = training_progress.done - completed_steps = training_progress.completed_steps - - if (done and self.config.enabled()) or self.config.enabled(completed_steps): - return self.evaluator.run(training_progress, run_index=self._config.get_run_count(completed_steps - 1)) - else: - return EvaluationMetrics() - - def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: - name_samples = self.evaluator.get_sampling_parameters() - if name_samples is None: - return None - run_count = self._config.get_run_count( - self._train_iters, - # There may be an extra evaluation after the last training step.s - not self._config.enabled(self._train_iters), - ) - return EvaluatorSamplingParameters(name_samples.dataset_name, name_samples.num_samples * run_count) - - class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): # TODO: Generalize data, schedule, logging, etc. _is_setup: bool = False _distributed: Distributed _run: Run _wandb: Wandb - _optimizer: Optimizer + _optimizer: Optimizer | None _completed_steps: int - _is_evaluation_only: bool - - _evaluator_runner: EvaluatorRunner - def __init__(self, config: TrainerConfig): super().__init__(config) - self._is_evaluation_only = config.training.train_iters == 0 + self._do_train = config.training.train_iters > 0 self._data = self._get_data() log_main_rank("Creating model...") self._multi_stage = self._config.model.get_model_class()( self._config.model, - optimizer_state_names=self._config.optimizer.state_names() if not self._is_evaluation_only else (), + optimizer_state_names=self._config.optimizer.state_names() if self._do_train else (), ) self._reference_models = {} for name, reference_config in self._config.reference_models.items(): @@ -152,47 +67,22 @@ def __init__(self, config: TrainerConfig): ) self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() - if not self._is_evaluation_only: - steps_per_split = { - PhaseType.training: {PhaseType.training.value.lower(): self._config.training.train_iters}, - PhaseType.test: {PhaseType.test.value.lower(): self._config.training.test_iters}, - } - - self._samples_per_split = { - phase: { - dataset_name: self._config.batch.batch_size * steps - for dataset_name, steps in datasets.items() - if steps > 0 - } - for phase, datasets in steps_per_split.items() - } - # Prune empty phases. 
- self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} - - # Setup the schedules - self._schedule = { - phase: { - dataset_name: Schedule( - multi_stage=self._multi_stage, - batch_config=self._config.batch, - schedule_config=self._config.schedule, - distributed_config=self._config.model.distributed, - phase=phase, - ) - for dataset_name in datasets - } - for phase, datasets in self._samples_per_split.items() - } - else: - self._samples_per_split = {} - - self._evaluator_runner = EvaluatorRunner( - evaluator_configs=self._config.training.evaluators, - batch_config=self._config.batch, - data_load_num_proc=self._config.training.num_workers, - train_iters=self._config.training.train_iters, - wandb_config=self._config.training.wandb, - ) + if self._do_train: + self._training_samples = self._config.batch.batch_size * self._config.training.train_iters + self._preprocessing_config = self._multi_stage.get_preprocessing_config(PhaseType.training) + self._schedule = Schedule( + config=self._config.schedule, + multi_stage=self._multi_stage, + batch_meta=self._preprocessing_config.get_batch_meta(), + distributed_config=self._config.model.distributed, + phase=PhaseType.training, + ) + + self._evaluators = { + name: config.get_evaluator(name, self._config.batch, self._config.training.num_workers) + for name, config in self._config.training.evaluators.items() + if config.enabled() + } def setup(self, distributed: Distributed, run: Run) -> None: assert distributed.config is self._config.model.distributed @@ -204,18 +94,14 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the model. with torch.no_grad(): log_main_rank("Setting up model...") - self._multi_stage.setup( - distributed, mode=StageMode.inference if self._is_evaluation_only else StageMode.training - ) + self._multi_stage.setup(distributed, mode=StageMode.training if self._do_train else StageMode.inference) for name, reference_model in self._reference_models.items(): log_main_rank(f"Setting up `{name}` reference model...") reference_model.fast_llm_model.setup(distributed, StageMode.inference) reference_model.setup() # Setup the optimizer. - if self._is_evaluation_only: - self._optimizer = None - else: + if self._do_train: param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) self._optimizer = self._config.optimizer.optimizer_cls( self._config.optimizer, @@ -223,59 +109,47 @@ def setup(self, distributed: Distributed, run: Run) -> None: grads_for_norm=grads_for_norm, distributed=self._distributed, ) + else: + self._optimizer = None # Setup the schedules. with torch.no_grad(): self._runner.setup(distributed, self._optimizer) # Setup the datasets. 
log_main_rank("Preparing datasets...") - sampling_parameters = {} - preprocessing_configs = {} - for phase, datasets in self._samples_per_split.items(): - for dataset_name, samples in datasets.items(): - sampling_parameters[dataset_name] = self._get_sampling_parameters({"num_samples": samples}) - preprocessing_configs[dataset_name] = self._get_preprocessing_config(phase) - for eval_sampling_params in self._evaluator_runner.get_sampling_parameters(): - sampling_parameters[eval_sampling_params.dataset_name] = self._get_sampling_parameters( - {"num_samples": eval_sampling_params.num_samples} - ) - preprocessing_configs[eval_sampling_params.dataset_name] = self._get_preprocessing_config( - PhaseType.inference - ) - self._data.setup( - distributed, - sampling_parameters, - preprocessing_configs, - None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", - timeout=self._config.training.timeout, - ) - # Must be called with all arguments set up - self._evaluator_runner.setup( - distributed=self._distributed, - run=self._run, - multi_stage=self._multi_stage, - runner=self._runner, - data=self._data, - wandb=self._wandb, - phase=PhaseType.inference if self._is_evaluation_only else PhaseType.validation, + self._data.setup(None if run.experiment_directory is None else run.experiment_directory / "dataset_cache") + self._data.sample_dataset( + PhaseType.training, + self._preprocessing_config, + self._training_samples, ) + for evaluator in self._evaluators.values(): + run_count = self._config.training.evaluators[name].get_count(self._config.training.train_iters) + # There may be an extra evaluation after the last training step. + if not self._config.training.evaluators[name].enabled(self._config.training.train_iters): + run_count += 1 + evaluator.setup(multi_stage=self._multi_stage, runner=self._runner, data=self._data, run_count=run_count) + + # Make sure everyone is done before continuing. 
+ safe_barrier(distributed.world_group, "data_preparation", self._config.training.timeout) + self._is_setup = True @abc.abstractmethod def _get_data(self) -> Data: pass - def _get_sampling_parameters( - self, parameters: dict[str, typing.Any], *, _return_dict: bool = False - ) -> SamplingParameters | dict[str, typing.Any]: - return parameters if _return_dict else SamplingParameters(**parameters) - - def _get_preprocessing_config( - self, phase: PhaseType, *, _return_dict: bool = False - ) -> PreprocessingConfig | dict[str, typing.Any]: - return {} if _return_dict else NullPreprocessingConfig() + def _get_completion_metrics(self) -> dict[str, int | float]: + assert self._is_setup + return { + "total_steps": self._config.training.train_iters, + "completed_steps": self._completed_steps, + "consumed_samples": self._consumed_samples, + "consumed_tokens": self._consumed_tokens, + "percent_done": 100 * self._completed_steps / self._config.training.train_iters, + } @property def _consumed_samples(self) -> int: @@ -299,44 +173,12 @@ def _run_training(self) -> None: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"After initial setup", str)) self._run.save_logged_tensors("init") - if self._is_evaluation_only: - assert len(self._samples_per_split) == 0 - - if PhaseType.training in self._samples_per_split: - done = self._completed_steps >= self._config.training.train_iters - if done: - metrics = {} - log_main_rank("Training already completed, nothing to do ...") - else: - done, metrics = self._train() + if not self._do_train: + self._run_evaluators(True, {}) + elif self._completed_steps >= self._config.training.train_iters: + log_main_rank("Training already completed, nothing to do ...") else: - metrics = {} - done = True - self._evaluator_runner.run( - metrics=metrics, - # This is set to ensure that evaluators like lm_eval log results at the correct step if a checkpoint was loaded. - training_progress=TrainingProgress( - done=done, - completed_steps=self._completed_steps, - consumed_samples=self._consumed_samples, - consumed_tokens=self._consumed_tokens, - ), - ) - - if done and PhaseType.test in self._samples_per_split: - log_main_rank(lambda: f"Running test phase ...") - test_iterator = self._get_data_iterator(PhaseType.test.value.lower()) - metrics_key = PhaseType.test.value - metrics[metrics_key] = self._evaluate_loss( - data_iterator=test_iterator, - phase=PhaseType.test, - num_iters=self._config.training.test_iters, - ) - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_definitions, PhaseType.test) - log_main_rank(formatted_metrics) - self._wandb.alert("Testing results", formatted_metrics, "WARN") - # TODO: This may erase some metrics. - self._wandb.log_metrics(self._completed_steps, metrics, commit=True) + self._train() def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # Tracking loss. @@ -357,8 +199,6 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._config.training.prefetch_factor, ) - has_test_phase = PhaseType.test in self._samples_per_split - log_main_rank("Training ...") # TODO: Synchronization is probably unnecessary. 
@@ -380,7 +220,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # (Also preprocessing adds overhead) reduced_losses, update_successful, train_metrics = self._runner.run_step( train_iterator, - self._schedule[PhaseType.training][PhaseType.training.value.lower()], + self._schedule, iteration=self._completed_steps, return_metrics=is_logging, ) @@ -410,34 +250,21 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: remaining_time = average_time_per_iteration * ( self._config.training.train_iters - self._completed_steps ) - model_compute, hardware_compute = self._schedule[PhaseType.training][ - PhaseType.training.value.lower() - ].compute_usage - model_tflops = math.nan if model_compute is None else model_compute / time_per_iteration - hardware_tflops = ( - math.nan if hardware_compute is None else hardware_compute / time_per_iteration - ) - metrics_key = PhaseType.training.value metrics[metrics_key] = { - "train_iters": self._config.training.train_iters, "batch_size": self._config.batch.batch_size, - "iteration": self._completed_steps, **{ name: (value / advanced_iters if advanced_iters > 0 else float("nan")) for name, value in total_losses.items() }, - "consumed_samples": self._consumed_samples, - "consumed_tokens": self._consumed_tokens, + **self._get_completion_metrics(), "step_time_ms": time_per_iteration * 1000, "step_time_average_ms": average_time_per_iteration * 1000, "remaining_time": remaining_time, "completion_time": time.time() + remaining_time, - "percent_done": 100 * self._completed_steps / self._config.training.train_iters, "skipped_iters": skipped_iters, "nan_iters": nan_iters, - "model_tflops": model_tflops, - "hardware_tflops": hardware_tflops, + **self._schedule.get_compute_metrics(time_per_iteration), "tokens_per_sec_per_gpu": ( (self._config.batch.sequence_length * self._config.batch.batch_size) / self._config.model.distributed.world_size @@ -469,21 +296,10 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: stop = done or self._config.training.shutdown.enabled(self._completed_steps) # Evaluation - # TODO: Adjust valid iterator length. 
- self._evaluator_runner.run( - metrics=metrics, - training_progress=TrainingProgress( - done=done, - completed_steps=self._completed_steps, - consumed_samples=self._consumed_samples, - consumed_tokens=self._consumed_tokens, - ), - ) + self._run_evaluators(done, metrics) if is_main_rank() and metrics: - self._wandb.log_metrics(self._completed_steps, metrics, commit=not (done and has_test_phase)) - - stop = done or self._config.training.shutdown.enabled(self._completed_steps) + self._wandb.log_metrics(self._completed_steps, metrics, commit=True) if self._config.training.export.enabled(None if done else self._completed_steps): self._save_checkpoint(self._config.training.export, metrics) @@ -523,14 +339,14 @@ def _prepare_training_state(self) -> None: ) self._multi_stage.load_checkpoint(self._config.pretrained) else: - if self._is_evaluation_only: + if not self._do_train: raise ValueError( "Evaluation mode, model need to be trained first or pretrained checkpoint is provided for loading" ) log_main_rank(f"Initializing training state from scratch...") self._multi_stage.initialize_weights() - if not self._is_evaluation_only: + if self._do_train: self._optimizer.reset_state() self._completed_steps = 0 else: @@ -608,7 +424,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) ) assert metadata is not None - if not self._is_evaluation_only: + if self._do_train: self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. @@ -636,3 +452,15 @@ def _get_last_checkpoint(self) -> int | None: iteration = -1 iteration = self._run.broadcast_int(iteration) return iteration if iteration >= 0 else None + + def _run_evaluators(self, done: bool, metrics: dict[str, typing.Any] | None = None) -> None: + for name, evaluator in self._evaluators.items(): + if self._config.training.evaluators[name].enabled(None if done else self._completed_steps): + evaluator.run( + run_index=self._config.get_run_count(self._completed_steps - 1), + metrics=(evaluator_metrics := self._get_completion_metrics()), + ) + if metrics is not None: + if "evaluations" not in metrics: + metrics["evaluations"] = {} + metrics["evaluations"][name] = evaluator_metrics diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 85b9bde1d..57b9b82b8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block @@ -23,7 +23,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert +from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -116,6 +116,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c * 
(self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) ) + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return safe_merge_dicts([loss.get_preprocessing_config(phase) for loss in self.losses]) + def get_output_weights(self) -> list[torch.Tensor]: return [self.output_weights] diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py index 177a681a4..ad8ff49d9 100644 --- a/fast_llm/layers/language_model/loss/dpo.py +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -2,6 +2,7 @@ import torch +from fast_llm.engine.distributed.config import PhaseType from fast_llm.layers.language_model.loss.config import LanguageModelDPOLossConfig, LanguageModelLossKwargs from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward @@ -18,6 +19,9 @@ def __init__(self, *args, **kwargs): if self._vocab_parallel: raise NotImplementedError() + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return {"use_preference_spans": True} + def forward_backward( self, logits: "torch.Tensor", diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index f1f65ac39..9506b3d80 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -5,7 +5,7 @@ from fast_llm.config import Configurable from fast_llm.core.ops import split_op -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs from fast_llm.utils import Assert @@ -47,6 +47,9 @@ def forward_backward( ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass + def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]: + return {} + @property def name(self) -> str: return self._name @@ -61,16 +64,8 @@ def _prepare_target( kwargs: dict[str, typing.Any], split_index: int = 0, *, - multi_token_format: bool = False, sequence_parallel: bool = True, ) -> torch.Tensor | None: - # MTP shift - if multi_token_format and self._prediction_heads > 1: - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - target = target.unflatten( - 0, (kwargs[LanguageModelKwargs.batch_dim].size, sequence_q + self._prediction_heads - 1) - )[:, self._prediction_distance : self._prediction_distance + sequence_q].flatten(0, 1) - # Get the local chunk. 
if sequence_parallel and self._sequence_parallel: target = split_op(target, self._parallel_dim.group, 0) @@ -93,9 +88,7 @@ def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None: return grad_output def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0): - return self._prepare_target( - kwargs[LanguageModelLossKwargs.labels], kwargs, split_index, multi_token_format=True - ) + return self._prepare_target(kwargs[LanguageModelLossKwargs.labels], kwargs, split_index) def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 84b945a67..a25b3b0f8 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -92,13 +92,11 @@ PhaseType.training: _TRAINING_METRIC_FORMAT_KEYS, PhaseType.validation: _VALIDATION_METRIC_FORMAT_KEYS, PhaseType.inference: _VALIDATION_METRIC_FORMAT_KEYS, - PhaseType.test: _VALIDATION_METRIC_FORMAT_KEYS, } _METRIC_FORMATS = { PhaseType.training: _TRAINING_METRIC_FORMATS, PhaseType.validation: _VALIDATION_METRIC_FORMATS, PhaseType.inference: _VALIDATION_METRIC_FORMATS, - PhaseType.test: _VALIDATION_METRIC_FORMATS, } diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index ddcbcf696..238c7cfc0 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -49,16 +49,6 @@ class GPTBatchConfig(BatchConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - use_loss_masking_spans: bool = Field( - default=False, - desc="Read loss masking spans from the dataset.", - hint=FieldHint.feature, - ) - use_preference_spans: bool = Field( - default=False, - desc="Read dpo data (chosen and rejected spans) from the dataset.", - hint=FieldHint.feature, - ) truncate_documents: bool | None = Field( default=True, desc=( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2f96f6f91..33519a415 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,6 +5,7 @@ import torch +from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.distributed.config import DistributedConfig, PhaseType @@ -138,8 +139,11 @@ def _head_reference_models(self) -> set[str]: class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - # TODO: Can we drop class? 
- pass + def get_preprocessing_config( + self, + phase: PhaseType, + ) -> LanguageModelBatchPreprocessingConfig: + return LanguageModelBatchPreprocessingConfig(phase=phase, **self._base_model.get_preprocessing_config(phase)) class GPTInferenceRunner(InferenceRunner): diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index e65556501..ce789e4dc 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -1,10 +1,6 @@ import logging -import typing -from fast_llm.data.batch.language_model import LanguageModelBatchPreprocessingConfig from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -17,27 +13,3 @@ def _get_data(self) -> GPTData: config=self._config.data, distributed_config=self._config.model.distributed, ) - - def _get_sampling_parameters( - self, parameters: dict[str, typing.Any], *, _return_dict: bool = False - ) -> SamplingParameters | dict[str, typing.Any]: - parameters = super()._get_sampling_parameters(parameters, _return_dict=True) - parameters.update( - { - "sequence_length": self._config.batch.sequence_length, - "truncate_documents": self._config.batch.truncate_documents, - "extra_tokens": self._config.model.base_model.head.prediction_heads, - } - ) - return parameters if _return_dict else SamplingParameters(**parameters) - - def _get_preprocessing_config( - self, phase: PhaseType, *, _return_dict: bool = False - ) -> LanguageModelBatchPreprocessingConfig | dict[str, typing.Any]: - out = { - "phase": phase, - "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_spans": self._config.batch.use_preference_spans, - **self._multi_stage.base_model.get_preprocessing_config(phase), - } - return out if _return_dict else LanguageModelBatchPreprocessingConfig.from_dict(out) diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py index 43a8f8885..780cdd294 100644 --- a/fast_llm/models/multimodal/trainer.py +++ b/fast_llm/models/multimodal/trainer.py @@ -1,6 +1,7 @@ import logging import typing +from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.models.gpt.trainer import GPTTrainer from fast_llm.models.multimodal.config import MultiModalTrainerConfig @@ -11,7 +12,7 @@ class MultiModalTrainer[ConfigType: MultiModalTrainerConfig](GPTTrainer[ConfigType]): def _get_preprocessing_config( self, *, _return_dict: bool = False - ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + ) -> LanguageModelBatchPreprocessingConfig | dict[str, typing.Any]: out = super()._get_preprocessing_config(_return_dict=True) out["image_patches"] = { "type": "image_patch", diff --git a/tests/data/common.py b/tests/data/common.py index 26aeda845..fd5ae0692 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -50,7 +50,7 @@ def get_sampling_data( ), preprocessing=preprocessing, cache_directory=cache_directory, - distributed=distributed, + distributed_config=DistributedConfig(use_cuda=torch.cuda.is_available()), dataset_name=phase.value, ) @@ -74,17 +74,11 @@ def get_test_data_and_compare_samples( preprocessing: LanguageModelPreprocessingConfig, ) -> GPTData: distributed_config = DistributedConfig(seed=87522, 
use_cuda=torch.cuda.is_available())
-    distributed = Distributed(distributed_config)
     if isinstance(samples_per_dataset, int):
-        samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset}
-
-    sampling_parameters = {
-        dataset_name: SamplingParameters(num_samples=num_samples, sequence_length=sequence_length)
-        for dataset_name, num_samples in samples_per_dataset.items()
-    }
+        samples_per_dataset = {PhaseType.training.value: samples_per_dataset}

     if isinstance(expected_samples, list):
-        expected_samples = {PhaseType.training.value.lower(): expected_samples}
+        expected_samples = {PhaseType.training.value: expected_samples}

     assert "sampling" not in config
     config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle)
@@ -96,12 +90,9 @@ def get_test_data_and_compare_samples(
         preprocessing, {"batch": batch_config, "type": None}
     )
     data = GPTData(GPTDataConfig.from_dict(config), distributed_config)
-    data.setup(
-        distributed,
-        sampling_parameters,
-        {dataset_name: preprocessing for dataset_name in samples_per_dataset},
-        cache_directory,
-    )
+    data.setup(cache_directory)
+    for dataset_name, num_samples in samples_per_dataset.items():
+        data.sample_dataset(dataset_name, preprocessing, num_samples)
     tokens = {
         phase: torch.stack(
             [

From dd536b880ea9a1482397d95a0654e59529fae253 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 19 Feb 2026 23:58:42 -0500
Subject: [PATCH 4/4] fixes

---
 examples/mistral.yaml                         |   1 -
 fast_llm/data/batch/config.py                 |   4 +-
 fast_llm/data/batch/language_model.py         |  86 +++---
 fast_llm/data/data/gpt/data.py                |   4 +-
 fast_llm/data/dataset/memmap/config.py        |  10 +-
 fast_llm/data/document/language_model.py      |   6 -
 fast_llm/engine/base_model/base_model.py      |   1 +
 fast_llm/engine/evaluation/evaluator.py       |   2 +-
 fast_llm/engine/inference/huggingface.py      |   1 +
 fast_llm/engine/multi_stage/fast_llm_model.py |   6 +-
 fast_llm/engine/schedule/runner.py            |   1 +
 fast_llm/engine/schedule/schedule.py          |   1 +
 fast_llm/engine/training/trainer.py           |   4 +-
 fast_llm/layers/attention/attention.py        |  29 +-
 fast_llm/layers/attention/preprocessing.py    |  58 ----
 fast_llm/layers/attention/rotary/rotary.py    |  10 +-
 fast_llm/layers/language_model/embedding.py   |   2 +-
 fast_llm/layers/language_model/head.py        |  12 +-
 .../layers/language_model/language_model.py   |  20 +-
 fast_llm/layers/language_model/loss/config.py |   2 +-
 fast_llm/layers/language_model/loss/dpo.py    |   4 +-
 fast_llm/layers/language_model/loss/loss.py   |  14 +-
 .../language_model/multi_token_prediction.py  |  11 +-
 fast_llm/layers/ssm/mamba.py                  |   5 +-
 fast_llm/models/gpt/conversion/mtp_llama.py   |   6 +-
 fast_llm/models/gpt/huggingface.py            |   6 +-
 fast_llm/models/gpt/model.py                  |  26 +-
 fast_llm/models/multimodal/model.py           |   9 +-
 tests/data/test_preprocessing.py              |  65 +++++
 tests/layers/test_lm_head.py                  |  57 ++--
 tests/test_loss_mask.py                       | 254 ------------------
 tests/utils/distributed_configs.py            |   2 +-
 32 files changed, 255 insertions(+), 464 deletions(-)
 delete mode 100644 fast_llm/layers/attention/preprocessing.py
 create mode 100644 tests/data/test_preprocessing.py
 delete mode 100644 tests/test_loss_mask.py

diff --git a/examples/mistral.yaml b/examples/mistral.yaml
index 904325c5c..ec045e3bb 100644
--- a/examples/mistral.yaml
+++ b/examples/mistral.yaml
@@ -8,7 +8,6 @@ training:
     evaluator:
       type: loss
       iterations: null
-  test_iters: 0
 batch:
   sequence_length: 4096
   micro_batch_size: 2
diff --git a/fast_llm/data/batch/config.py b/fast_llm/data/batch/config.py
index 360a07fb6..61dd3bdda 100644
--- a/fast_llm/data/batch/config.py
+++ b/fast_llm/data/batch/config.py
@@ -48,6 +48,8 @@ def _validate(self) -> None:
         Assert.custom(isinstance, self.tokenizer, (TokenizerConfig, NullPreprocessingConfig))

     def get_batch_meta(self) -> "PreprocessedBatch":
+        import torch
+
         from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch
         from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
         from fast_llm.data.document.token import TokenDocument
@@ -92,7 +94,7 @@ def __init__(self, config: ConfigType, micro_batches: list[MicroBatchType]):
         self._micro_batches = micro_batches

     @property
-    def micro_batches(self) -> list[MicroBatch]:
+    def micro_batches(self) -> list[MicroBatchType]:
         return self._micro_batches

     def __len__(self) -> int:
diff --git a/fast_llm/data/batch/language_model.py b/fast_llm/data/batch/language_model.py
index 06bc90e37..012966799 100644
--- a/fast_llm/data/batch/language_model.py
+++ b/fast_llm/data/batch/language_model.py
@@ -7,6 +7,7 @@
 from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedDimNames
+from fast_llm.tensor import TensorMeta


 @dataclasses.dataclass
@@ -19,6 +20,7 @@ class LanguageModelMicroBatch(MicroBatch):
     num_tokens: int  # Number of tokens in the micro-batch excluding padding at the end.
     sequence_length: int  # Total number of tokens across all micro-batches, including padding.
     document_lengths: list[int]
+    is_meta: bool
     labels: list[torch.Tensor] = dataclasses.field(default_factory=list)
     prediction_masks: list[torch.Tensor] = dataclasses.field(default_factory=list)
     cumulative_lengths_q: torch.Tensor | None = None
@@ -59,7 +61,9 @@ def from_documents(
         config: ConfigType,
         device: torch.device | None = None,
     ) -> typing.Self:
-        batch = LanguageModelBatch.from_documents(documents, pad_to_size=config.total_length)
+        batch = LanguageModelBatch.from_documents(
+            documents, pad_to_size=config.batch.micro_batch_size * config.total_length
+        )
         return cls.from_batch(batch, config=config, device=device)

     @classmethod
@@ -72,6 +76,7 @@ def from_batch(
         if device is None:
             device = batch.tokens.tokens.device
         batch.to_device_(device)
+        is_meta = device.type == "meta"

         token_dim = TensorDim(
             "token",
@@ -98,50 +103,57 @@ def from_batch(
             sequence_k = sequence_k_past + token_dim.size
             sequence_k_dim = TensorDim("sequence_k", sequence_k)
             cropped_sample = batch.crop(sequence_k_past, sequence_k)
-
+            if is_meta:
+                tokens = TensorMeta.from_dims(
+                    (token_dim,), tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64
+                )
+            else:
+                tokens = batch.tokens.tokens[sequence_k_past:sequence_k]
             micro_batch = LanguageModelMicroBatch(
-                tokens=batch.tokens.tokens[sequence_k_past:sequence_k],
+                tokens=tokens,
                 token_dim=token_dim,
                 hidden_token_dim=hidden_token_dim,
                 sequence_k_dim=sequence_k_dim,
                 num_tokens=min(sequence_k, batch.num_tokens) - sequence_k_past,
                 sequence_length=config.batch.sequence_length,
                 document_lengths=batch.tokens.lengths,
+                is_meta=is_meta,
             )
-            if config.return_cumulative_sequence_lengths:
-                micro_batch.cumulative_lengths_q, micro_batch.cumulative_lengths_k = (
-                    cropped_sample.tokens.get_cumulative_lengths(device)
-                )
-            if config.return_max_sequence_lengths:
-                micro_batch.max_length_q, micro_batch.max_length_k = cropped_sample.tokens.get_max_lengths(device)
-            if config.return_document_index:
-                micro_batch.document_index = cropped_sample.tokens.get_document_index()
-            if config.return_position_index:
-                micro_batch.position_index = cropped_sample.tokens.get_position_index()
-
-            for prediction_distance in range(1, config.predicted_tokens + 1):
-                label_begin = sequence_k_past + prediction_distance
-                label_end = sequence_k + prediction_distance
-                label_tokens = batch.tokens.crop(label_begin, label_end)
-                labels = label_tokens.tokens.clone()
-
-                # Apply loss masking spans.
-                if config.use_loss_masking_spans and batch.loss_masking_spans is not None:
-                    for span_begin, span_end in batch.loss_masking_spans.crop(label_begin, label_end).ranges:
-                        labels[span_begin:span_end] = -100
-
-                # Mask cross-document predictions.
-                document_end = 0
-                for length in label_tokens.lengths:
-                    document_end += length
-                    labels[max(document_end - prediction_distance, 0) : document_end] = -100
-
-                # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
-                micro_batch.labels.append(labels)
-                if config.return_prediction_mask:
-                    # TODO: Does the prediction mask really need all sources of masking?
-                    # (i.e. lack of labels doesn't mean we can't do predictions and compute other losses.)
-                    micro_batch.prediction_masks.append(labels > 0)
+            if not is_meta:
+                if config.return_cumulative_sequence_lengths:
+                    micro_batch.cumulative_lengths_q, micro_batch.cumulative_lengths_k = (
+                        cropped_sample.tokens.get_cumulative_lengths(device)
+                    )
+                if config.return_max_sequence_lengths:
+                    micro_batch.max_length_q, micro_batch.max_length_k = cropped_sample.tokens.get_max_lengths(device)
+                if config.return_document_index:
+                    micro_batch.document_index = cropped_sample.tokens.get_document_index()
+                if config.return_position_index:
+                    micro_batch.position_index = cropped_sample.tokens.get_position_index()
+
+                for prediction_distance in range(1, config.predicted_tokens + 1):
+                    label_begin = sequence_k_past + prediction_distance
+                    label_end = sequence_k + prediction_distance
+                    label_tokens = batch.tokens.crop(label_begin, label_end)
+                    labels = label_tokens.tokens.clone()
+
+                    # Apply loss masking spans.
+                    if config.use_loss_masking_spans and batch.loss_masking_spans is not None:
+                        for span_begin, span_end in batch.loss_masking_spans.crop(label_begin, label_end).ranges:
+                            labels[span_begin:span_end] = -100
+
+                    # Mask cross-document predictions.
+                    document_begin = label_tokens.lengths[0]
+                    for length in label_tokens.lengths[1:]:
+                        labels[document_begin : document_begin + prediction_distance] = -100
+                        document_begin += length
+
+                    # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
+                    micro_batch.labels.append(labels)
+                    if config.return_prediction_mask:
+                        # TODO: Does the prediction mask really need all sources of masking?
+                        # (i.e. lack of labels doesn't mean we can't do predictions and compute other losses.)
+                        micro_batch.prediction_masks.append(labels > 0)

             micro_batches.append(micro_batch)

         return cls(micro_batches=micro_batches, config=config)
diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index e15d95e90..5a24a7631 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -16,6 +16,7 @@
 from fast_llm.data.dataset.gpt.config import GPTSamplingData
 from fast_llm.data.dataset.monitor import DatasetMonitor
 from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
+from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.models.gpt.config import GPTBatchConfig
@@ -75,7 +76,8 @@ def sample_dataset(
         sampling = GPTSamplingData(
             config=self._config.sampling,
             parameters=sampling_parameters,
-            preprocessing=config,
+            # Conversion needed to avoid pickling issues.
+            preprocessing=LanguageModelPreprocessingConfig.from_dict(config, {"type": "language_model"}, strict=False),
             cache_directory=self._cache_directory,
             distributed_config=self._distributed_config,
             dataset_name=dataset_name,
diff --git a/fast_llm/data/dataset/memmap/config.py b/fast_llm/data/dataset/memmap/config.py
index ce5ecb06c..ed50f366b 100644
--- a/fast_llm/data/dataset/memmap/config.py
+++ b/fast_llm/data/dataset/memmap/config.py
@@ -4,11 +4,8 @@
 import pathlib
 import typing

-import torch
-
 from fast_llm.config import Config, Field, FieldHint, config_class
 from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig
-from fast_llm.data.dataset.indexed import IndexedDataset
 from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig
 from fast_llm.data.preprocessing.image_patch import ImagePatchConfig
 from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
@@ -16,6 +13,9 @@
 from fast_llm.utils import Assert, get_unique

 if typing.TYPE_CHECKING:
+    import torch
+
+    from fast_llm.data.dataset.indexed import IndexedDataset
     from fast_llm.data.dataset.memmap.abstract import (
         MemmapIndexedDatasetReader,
         MemmapReader,
@@ -201,6 +201,8 @@ def writer_class(self) -> "type[PatchWriter]":

     @property
     def _expected_buffer_size(self) -> int:
+        import torch
+
         return (
             self.num_patches * self.patch_size * self.data_type.torch.itemsize
             + ((1 + self.grid_dims) * self.num_patches + self.num_patch_groups + 2 * self.num_documents + 2)
@@ -255,6 +257,8 @@ def writer_class(self) -> "type[RangeWriter]":

     @property
     def _expected_buffer_size(self) -> int:
+        import torch
+
         return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize

     def get_metadata(self) -> dict[str, typing.Any]:
diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py
index 23de0605b..c0bccc5be 100644
--- a/fast_llm/data/document/language_model.py
+++ b/fast_llm/data/document/language_model.py
@@ -72,12 +72,6 @@ def crop(self, begin: int, end: int) -> typing.Self:

     def to_device_(self, device: "torch.device | str"):
         self.tokens.to_device_(device)
-        if self.loss_masking_spans is not None:
-            self.loss_masking_spans.to_device_(device)
-        if self.chosen_spans is not None:
-            self.chosen_spans.to_device_(device)
-        if self.rejected_spans is not None:
-            self.rejected_spans.to_device_(device)
         if self.image_patches is not None:
             self.image_patches.to_device_(device)
diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index 195a1508a..945daef89 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -184,6 +184,7 @@ def preprocess_batch(
         iteration: int,
         metrics: dict | None = None,
         extra_kwargs: dict[str, typing.Any] | None = None,
+        device: torch.device | None,
     ) -> list[tuple[torch.Tensor, dict]]:
         # TODO Move batch splitting elsewhere, align interface with LayerBase
         pass
diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py
index 8a3bd7e3d..7cabb06d1 100644
--- a/fast_llm/engine/evaluation/evaluator.py
+++ b/fast_llm/engine/evaluation/evaluator.py
@@ -87,7 +87,7 @@ def setup(
     ) -> None:
         super().setup(multi_stage, runner, data, run_count)

-        preprocessing_config = self._multi_stage.get_preprocessing_config(PhaseType.validation)
+        preprocessing_config = self._multi_stage.get_preprocessing_config(self._batch_config, PhaseType.validation)
         self._data.sample_dataset(
             self._name, preprocessing_config, run_count * self._config.iterations * self._batch_config.batch_size
         )
diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py
index aa1eaa401..5a07bd51b 100644
--- a/fast_llm/engine/inference/huggingface.py
+++ b/fast_llm/engine/inference/huggingface.py
@@ -52,6 +52,7 @@ def __init__(
         fast_llm_config = config.fast_llm_config
         config.fast_llm_config = None
         super().__init__(config, **kwargs)
+        self._fast_llm_model = fast_llm_model
         config.fast_llm_config = fast_llm_config

         self._inference_runner = self.runner_class(fast_llm_model, runner)
diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py
index 9ac6c5ccf..b1dc37649 100644
--- a/fast_llm/engine/multi_stage/fast_llm_model.py
+++ b/fast_llm/engine/multi_stage/fast_llm_model.py
@@ -10,6 +10,7 @@
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
 from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
+from fast_llm.engine.schedule.config import BatchConfig
 from fast_llm.functional.triton.pointwise import triton_fill
 from fast_llm.utils import Assert

@@ -81,10 +82,7 @@ def from_pretrained(
         return model

     @abc.abstractmethod
-    def get_preprocessing_config(
-        self,
-        phase: PhaseType,
-    ) -> BatchPreprocessingConfig:
+    def get_preprocessing_config(self, batch: BatchConfig, phase: PhaseType) -> BatchPreprocessingConfig:
         pass

     def initialize_weights(self, timeout: float | None = None) -> None:
diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index 92adfb1a9..0683153e5 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -344,6 +344,7 @@ def _preprocess_data(
                 "num_micro_batches": batch_config.sequential_micro_batches,
                 "micro_batch_splits": batch_config.micro_batch_splits,
             },
+            device=self._distributed.device,
         )
         for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data):
             kwargs.update(micro_batch_split=micro_batch_split)
diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py
index b0b72763e..8c932946b 100644
--- a/fast_llm/engine/schedule/schedule.py
+++ b/fast_llm/engine/schedule/schedule.py
@@ -144,6 +144,7 @@ def __init__(
             batch_meta,
             phase=self._phase,
             iteration=0,
+            device=None,
         )

         self._steps, self._first_grad_stage = self._create_steps()
diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py
index 0290a6468..b0f48b408 100644
--- a/fast_llm/engine/training/trainer.py
+++ b/fast_llm/engine/training/trainer.py
@@ -69,7 +69,9 @@ def __init__(self, config: TrainerConfig):

         if self._do_train:
             self._training_samples = self._config.batch.batch_size * self._config.training.train_iters
-            self._preprocessing_config = self._multi_stage.get_preprocessing_config(PhaseType.training)
+            self._preprocessing_config = self._multi_stage.get_preprocessing_config(
+                self._config.batch, PhaseType.training
+            )
             self._schedule = Schedule(
                 config=self._config.schedule,
                 multi_stage=self._multi_stage,
diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index 0eaae34f7..389abfbb3 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -11,7 +11,6 @@
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType
 from fast_llm.functional.autograd import wrap_forward_backward
 from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs
-from fast_llm.layers.attention.preprocessing import preprocess_for_varlen
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.block import BlockWithBias
 from fast_llm.tensor import TensorMeta
@@ -185,7 +184,7 @@ def _attn_backup(
         query = (
             query.unflatten(1, (self._local_head_groups, self._local_heads_per_group))
             .transpose(0, 1)
-            .view(self._local_head_groups, sq * self._local_heads_per_group, self._config.head_size)
+            .reshape(self._local_head_groups, sq * self._local_heads_per_group, self._config.head_size)
         )
         # sk, head_group, head_size -> head_group, head_size, sk
         key = key.movedim(0, 2)
@@ -353,7 +352,7 @@ def _forward(

     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: ====== Account for varlen =======
-        sequence_q_dim: TensorDim = kwargs[AttentionKwargs.sequence_q_dim]
+        sequence_q_dim: TensorDim = kwargs[AttentionKwargs.token_dim]
         sequence_k_dim: TensorDim = kwargs[AttentionKwargs.sequence_k_dim]

         if config.global_:
@@ -406,11 +405,14 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
         )

     def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]:
-        out = {}
-        if self._implementation == AttentionImplementation.flash:
-            out["return_cumulative_sequence_lengths"] = True
-            out["return_max_sequence_lengths"] = True
-        return out
+        return (
+            {
+                "return_cumulative_sequence_lengths": True,
+                "return_max_sequence_lengths": True,
+            }
+            if self._implementation == AttentionImplementation.flash
+            else {"return_document_index": True}
+        )

     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self._rotary.preprocess(kwargs)
@@ -420,7 +422,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
     def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
         device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device
         sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size
-        sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size
+        sequence_q = kwargs[AttentionKwargs.token_dim].size
         if self._config.causal:
             if (
                 sequence_length := kwargs[AttentionKwargs.sequence_length]
@@ -436,15 +438,12 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non
                 if self._config.window_size is not None:
                    self._backup_attention_mask.triu_(-self._config.window_size + 1)
-            attention_mask = self._backup_attention_mask[
-                None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k
-            ]
+            attention_mask = self._backup_attention_mask[None, sequence_k - sequence_q : sequence_k, None, :sequence_k]
         else:
             attention_mask = None

-        preprocess_for_varlen(kwargs, device, return_seq_idx=True)
-        document_mask = (kwargs[AttentionKwargs.seq_idx][:, None, :] == kwargs[AttentionKwargs.seq_idx][:, :, None])[
-            :, None, sequence_k - sequence_q : sequence_k, None, :sequence_k
+        document_mask = (kwargs[AttentionKwargs.seq_idx][None, :] == kwargs[AttentionKwargs.seq_idx][:, None])[
+            None, sequence_k - sequence_q : sequence_k, None, :sequence_k
         ]
         if attention_mask is None:
             attention_mask = document_mask
diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py
deleted file mode 100644
index a9d9936c5..000000000
--- a/fast_llm/layers/attention/preprocessing.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import typing
-
-import torch
-
-from fast_llm.layers.attention.config import MixerKwargs
-from fast_llm.utils import Assert
-
-
-def preprocess_for_varlen(
-    kwargs: dict[str, typing.Any],
-    device: torch.device,
-    return_cu_seqlens: bool = False,
-    return_max_seqlen: bool = False,
-    return_seq_idx: bool = False,
-    return_position_ids: bool = False,
-) -> None:
-    """
-    Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func:
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375
-    cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively.
-    Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k.
-    If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally
-    also contain previous tokens from the first document in micro-sequence.
-    We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths.
-    """
-
-    # TODO: ====== Fix (need to know how much first sequence was cropped) ======
-    Assert.eq(kwargs[MixerKwargs.sequence_k_dim].global_size, kwargs[MixerKwargs.sequence_q_dim].global_size)
-
-    sequence_lengths = [
-        sequence_length
-        for sequence_lengths in kwargs[MixerKwargs.sequence_lengths]
-        for sequence_length in sequence_lengths
-    ]
-    if return_cu_seqlens:
-        cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32, device=device).cumsum(
-            0, dtype=torch.int32
-        )
-        kwargs[MixerKwargs.cu_seqlens_q] = cu_seqlens_q
-        kwargs[MixerKwargs.cu_seqlens_k] = cu_seqlens_q
-    if return_max_seqlen:
-        max_seqlen_q = torch.full((1,), max(sequence_lengths), dtype=torch.int32, device=device)
-        kwargs[MixerKwargs.max_seqlen_q] = max_seqlen_q
-        kwargs[MixerKwargs.max_seqlen_k] = max_seqlen_q
-    if return_seq_idx:
-        kwargs[MixerKwargs.seq_idx] = torch.cat(
-            [
-                torch.full((sequence_length,), i, dtype=torch.int32, device=device)
-                for i, sequence_length in enumerate(sequence_lengths)
-            ]
-        )
-    if return_position_ids:
-        kwargs[MixerKwargs.position_ids] = torch.cat(
-            [
-                torch.arange(sequence_length, dtype=torch.int32, device=device)
-                for i, sequence_length in enumerate(sequence_lengths)
-            ]
-        )
diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py
index 307256a72..9e28b66c6 100644
--- a/fast_llm/layers/attention/rotary/rotary.py
+++ b/fast_llm/layers/attention/rotary/rotary.py
@@ -93,7 +93,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
             self._create_tensors(kwargs[AttentionKwargs.sequence_length], kwargs[AttentionKwargs.device])
         sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size
         kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[
-            :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k
+            :, sequence_k - kwargs[AttentionKwargs.token_dim].size : sequence_k
         ]
         kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k]

@@ -124,9 +124,9 @@ def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.d
         # We preform the calculation in high precision because it matters for rotary embeddings.
         positions = torch.arange(sequence_length, device=device, dtype=torch.float64)
         angles = torch.outer(positions, self._get_angle_scales(head_size, device))
-        frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64)
+        frequencies = torch.polar(torch.ones_like(angles), angles)[:, None, :].to(torch.complex64)
         frequencies = convert_rotary_complex_to_real(
-            torch.view_as_real(frequencies).flatten(-2), head_size, 3
+            torch.view_as_real(frequencies).flatten(-2), head_size, 2
         ).contiguous()
         return frequencies

@@ -223,9 +223,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
             self._frequencies.T.unsqueeze(1),
             out=angles.view(-1, 2, self._head_size // 4).permute(1, 0, 2),
         )
-        frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64)
+        frequencies = torch.polar(torch.ones_like(angles), angles)[:, None, :].to(torch.complex64)
         frequencies = convert_rotary_complex_to_real(
-            torch.view_as_real(frequencies).flatten(-2), self._head_size, 3
+            torch.view_as_real(frequencies).flatten(-2), self._head_size, 2
         ).contiguous()
         # TODO: Support different q and k frequencies.
         kwargs[AttentionKwargs.rotary_freq_q] = frequencies
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py
index ed685b416..1c5e51410 100644
--- a/fast_llm/layers/language_model/embedding.py
+++ b/fast_llm/layers/language_model/embedding.py
@@ -179,7 +179,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
         return 0

     def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]:
-        out = {"vocab_size": self.embeddings.vocab_size}
+        out = {"vocab_size": self._config.vocab_size}
         if self._config.position_embeddings.enabled:
             out["return_position_index"] = True
         return out
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 57b9b82b8..06cc7a2ea 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -47,7 +47,7 @@ def __init__(
         hidden_dim: TensorDim,
         lr_scale: float | None,
         peft: PeftConfig | None,
-        prediction_distance: int = 0,
+        prediction_distance: int = 1,
         loss_coefficient: float = 1.0,
     ):
         super().__init__(
@@ -57,9 +57,9 @@ def __init__(
             lr_scale=lr_scale,
             peft=peft,
         )
-        Assert.in_range(prediction_distance, 0, self._config.prediction_heads)
+        Assert.in_range_incl(prediction_distance, 1, self._config.prediction_heads)
         self._prediction_distance = prediction_distance
-        self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1
+        self._is_last_head = self._prediction_distance == self._config.prediction_heads

         self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel
         self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
@@ -89,7 +89,7 @@ def __init__(
         loss_coefficient = (
             1.0
             if self._config.prediction_loss_coefficient is None
-            else self._config.prediction_loss_coefficient[self._prediction_distance]
+            else self._config.prediction_loss_coefficient[self._prediction_distance - 1]
         )
         self.losses = torch.nn.ModuleList(
             [
@@ -117,7 +117,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
         )

     def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]:
-        return safe_merge_dicts([loss.get_preprocessing_config(phase) for loss in self.losses])
+        return safe_merge_dicts(*(loss.get_preprocessing_config(phase) for loss in self.losses))

     def get_output_weights(self) -> list[torch.Tensor]:
         return [self.output_weights]
@@ -295,7 +295,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
         ]

     def _get_full_loss_name(self, name) -> str:
-        return name if self._prediction_distance == 0 else f"{name}_{self._prediction_distance}"
+        return name if self._prediction_distance == 1 else f"{name}_{self._prediction_distance}"

     @functools.cached_property
     def _total_loss_name(self) -> str:
diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py
index bdd261d28..099051cfc 100644
--- a/fast_llm/layers/language_model/language_model.py
+++ b/fast_llm/layers/language_model/language_model.py
@@ -71,7 +71,7 @@ def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]:
             self.embeddings.get_preprocessing_config(phase),
             self.decoder.get_preprocessing_config(phase),
             self.head.get_preprocessing_config(phase),
-            {} if self.multi_token_prediction is None else self.multi_token_prediction.get_preprocessing_config(phase),
+            self.multi_token_prediction.get_preprocessing_config(phase),
         )

     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
@@ -79,16 +79,16 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self.embeddings.preprocess(kwargs)
         self.decoder.preprocess(kwargs)
         self.head.preprocess(kwargs)
-        if self.multi_token_prediction is not None:
-            self.multi_token_prediction.preprocess(kwargs)
+        self.multi_token_prediction.preprocess(kwargs)

     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
         # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable?
-        losses = (
-            self.embeddings.get_loss_definitions(count)
-            + self.decoder.get_loss_definitions(count)
-            + self.head.get_loss_definitions(count)
+        return sum(
+            (
+                self.embeddings.get_loss_definitions(count),
+                self.decoder.get_loss_definitions(count),
+                self.head.get_loss_definitions(count),
+                self.multi_token_prediction.get_loss_definitions(count),
+            ),
+            [],
         )
-        if self.multi_token_prediction is not None:
-            losses += self.multi_token_prediction.get_loss_definitions(count)
-        return losses
diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py
index b6a2ef175..803ac05f1 100644
--- a/fast_llm/layers/language_model/loss/config.py
+++ b/fast_llm/layers/language_model/loss/config.py
@@ -42,7 +42,7 @@ def get_layer(
         self,
         distributed_config: DistributedConfig,
         name: str,
-        prediction_distance: int = 0,
+        prediction_distance: int = 1,
         prediction_heads: int = 1,
         vocab_parallel: bool = False,
         num_splits: int = 1,
diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py
index ad8ff49d9..4eb7446e5 100644
--- a/fast_llm/layers/language_model/loss/dpo.py
+++ b/fast_llm/layers/language_model/loss/dpo.py
@@ -10,11 +10,11 @@
 class LanguageModelDPOLoss[ConfigType: LanguageModelDPOLossConfig](LanguageModelLoss[ConfigType]):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if self._prediction_distance > 0:
+        if self._prediction_distance > 1:
             raise NotImplementedError()
         if self._num_splits > 1:
             raise NotImplementedError()
-        if self._prediction_distance > 0:
+        if self._prediction_distance > 1:
             raise NotImplementedError()
         if self._vocab_parallel:
             raise NotImplementedError()
diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py
index 9506b3d80..07568ccc5 100644
--- a/fast_llm/layers/language_model/loss/loss.py
+++ b/fast_llm/layers/language_model/loss/loss.py
@@ -18,7 +18,7 @@ def __init__(
         distributed_config: DistributedConfig,
         *,
         name: str,
-        prediction_distance: int = 0,
+        prediction_distance: int = 1,
         prediction_heads: int = 1,
         vocab_parallel: bool = False,
         num_splits: int = 1,
@@ -26,7 +26,7 @@ def __init__(
         weight: float = 1.0,
     ):
         super().__init__(config)
-        Assert.in_range(prediction_distance, 0, prediction_heads)
+        Assert.in_range_incl(prediction_distance, 1, prediction_heads)
         self._prediction_distance = prediction_distance
         self._prediction_heads = prediction_heads
         self._name = name
@@ -88,11 +88,17 @@ def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None:
         return grad_output

     def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0):
-        return self._prepare_target(kwargs[LanguageModelLossKwargs.labels], kwargs, split_index)
+        return self._prepare_target(
+            kwargs[LanguageModelLossKwargs.labels][self._prediction_distance - 1], kwargs, split_index
+        )

     def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0):
         loss_mask = kwargs.get(LanguageModelKwargs.loss_mask)
-        return None if loss_mask is None else self._prepare_target(loss_mask, kwargs, split_index)
+        return (
+            None
+            if loss_mask is None
+            else self._prepare_target(loss_mask[self._prediction_distance - 1], kwargs, split_index)
+        )

     def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0):
         Assert.incl(
diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py
index a828cacc1..f7979ae53 100644
--- a/fast_llm/layers/language_model/multi_token_prediction.py
+++ b/fast_llm/layers/language_model/multi_token_prediction.py
@@ -47,9 +47,9 @@ def __init__(
                     peft=self._peft,
                     # The last block only returns the model output.
                     # The previous blocks return a stack of shared_hidden and transformer_output.
-                    return_input=index < self._config.prediction_heads - 1,
+                    return_input=prediction_distance < self._config.prediction_heads,
                 )
-                for index in range(1, self._config.prediction_heads)
+                for prediction_distance in range(2, self._config.prediction_heads + 1)
             ]
         )
         self.heads = torch.nn.ModuleList(
@@ -61,9 +61,9 @@ def __init__(
                     hidden_dim=hidden_dim,
                     lr_scale=lr_scale,
                     peft=peft,
-                    prediction_distance=index,
+                    prediction_distance=prediction_distance,
                 )
-                for index in range(1, self._config.prediction_heads)
+                for prediction_distance in range(2, self._config.prediction_heads + 1)
             ]
         )

@@ -88,8 +88,7 @@ def get_output_weights(self) -> list[torch.Tensor]:
         return sum((head.get_output_weights() for head in self.heads), [])

     def get_preprocessing_config(self, phase: PhaseType) -> dict[str, typing.Any]:
-        if self._enabled:
-            self._layers_with_namespace[0].get_preprocessing_config(phase)
+        return self._layers_with_namespace[0].get_preprocessing_config(phase) if self._enabled else {}

     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         if self._enabled:
diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py
index 275a1fae9..f1df8059f 100644
--- a/fast_llm/layers/ssm/mamba.py
+++ b/fast_llm/layers/ssm/mamba.py
@@ -165,8 +165,9 @@ def _forward(
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         assert _mamba_available

-        sequence_length = kwargs[BlockKwargs.sequence_q_dim].size
-        token_shape = (1, kwargs[BlockKwargs.sequence_q_dim].size)
+        sequence_length = kwargs[BlockKwargs.token_dim].size
+        token_shape = (1, sequence_length)
+        # TODO: ====== Keep flat ======
         # inner_projection : (local_tokens, hidden) -> (batch, sequence, local_inner_projection)
         inner_projection = self.in_proj(input_).unflatten(0, token_shape)
         dt = self.dt_proj(self.dt_in_proj(input_)).unflatten(0, token_shape)
diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py
index 0c58b7be5..05b6e4bbe 100644
--- a/fast_llm/models/gpt/conversion/mtp_llama.py
+++ b/fast_llm/models/gpt/conversion/mtp_llama.py
@@ -41,10 +41,10 @@ def get_converters(
         return super().get_converters(config, exported_config) + [
             cls.normalization_converter_class.get_converters(
                 config.head.normalization,
-                f"multi_token_prediction.heads.{prediction_distance - 1}.final_norm",
-                f"model.mtp_norms.{prediction_distance}",
+                f"multi_token_prediction.heads.{prediction_distance - 2}.final_norm",
+                f"model.mtp_norms.{prediction_distance-1}",
             )
-            for prediction_distance in range(1, config.prediction_heads)
+            for prediction_distance in range(2, config.prediction_heads + 1)
         ]

diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py
index fd02a6dc3..2bba685b9 100644
--- a/fast_llm/models/gpt/huggingface.py
+++ b/fast_llm/models/gpt/huggingface.py
@@ -121,7 +121,11 @@ def _inner_forward(
             kwargs_meta[BlockKwargs.output_hidden_states] = [re.compile(pattern) for pattern in output_hidden_states]

         ((input_, kwargs),) = self.fast_llm_base_model.preprocess_batch(
-            batch, [(input_meta, kwargs_meta)], phase=PhaseType.inference, iteration=iteration
+            batch,
+            [(input_meta, kwargs_meta)],
+            phase=PhaseType.inference,
+            iteration=iteration,
+            device=self.fast_llm_model.distributed.device,
         )

         if past_key_values is not None:
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 33519a415..58a0fa56d 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -49,16 +49,15 @@ def preprocess_batch(
         iteration: int,
         metrics: dict | None = None,
         extra_kwargs: dict[str, typing.Any] | None = None,
+        device: torch.device | None,
     ) -> list[tuple[torch.Tensor, dict]]:
-        # TODO Move batch splitting elsewhere, align interface with LayerBase
-        assert self._is_setup
-
         reference_preprocessed_batches = {}
         for name, reference_model in self._reference_models.items():
             reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch(
                 batch,
                 phase=PhaseType.inference,
                 iteration=iteration,
+                device=device,
             )

         preprocessed = []
@@ -66,13 +65,14 @@ def preprocess_batch(
         for micro_sequence_index, micro_sequence in enumerate(batch.micro_batches):
             pasts = presents
             presents = None if micro_sequence_index == len(batch) - 1 else []
-            micro_sequence.to_device_(self._distributed.device)
+            if device is not None:
+                micro_sequence.to_device_(device)
             kwargs: dict[str, typing.Any] = {
                 LanguageModelKwargs.phase: phase,
                 AttentionKwargs.past_key_values: pasts,
                 AttentionKwargs.presents: presents,
                 LanguageModelKwargs.iteration: iteration,
-                LanguageModelKwargs.device: self._distributed.device,
+                LanguageModelKwargs.device: device,
                 LanguageModelKwargs.output_hidden_states: [],
                 LanguageModelKwargs.hidden_states: {},
                 LanguageModelKwargs.token_dim: micro_sequence.token_dim,
@@ -87,10 +87,8 @@ def preprocess_batch(
                 AttentionKwargs.cu_seqlens_k: micro_sequence.cumulative_lengths_k,
                 AttentionKwargs.max_seqlen_q: micro_sequence.max_length_q,
                 AttentionKwargs.max_seqlen_k: micro_sequence.max_length_k,
-                LanguageModelKwargs.seq_idx: micro_sequence.document_index,
+                AttentionKwargs.seq_idx: micro_sequence.document_index,
                 LanguageModelKwargs.position_ids: micro_sequence.position_index,
-                LanguageModelKwargs.chosen_spans: micro_sequence.chosen_spans,
-                LanguageModelKwargs.rejected_spans: micro_sequence.rejected_spans,
             }
             if extra_kwargs is not None:
                 Assert.empty(kwargs.keys() & extra_kwargs.keys())
@@ -112,7 +110,8 @@ def preprocess_batch(
                     layer_name: tensor
                     for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items()
                 }
-            self.preprocess(kwargs)
+            if not micro_sequence.is_meta:
+                self.preprocess(kwargs)
             preprocessed.append((micro_sequence.tokens, kwargs))

         return preprocessed
@@ -140,10 +139,13 @@ def _head_reference_models(self) -> set[str]:

 class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
     def get_preprocessing_config(
-        self,
-        phase: PhaseType,
+        self, batch: GPTBatchConfig, phase: PhaseType
     ) -> LanguageModelBatchPreprocessingConfig:
-        return LanguageModelBatchPreprocessingConfig(phase=phase, **self._base_model.get_preprocessing_config(phase))
+        return LanguageModelBatchPreprocessingConfig(
+            phase=phase,
+            batch=batch,
+            **self._base_model.get_preprocessing_config(phase),
+        )


 class GPTInferenceRunner(InferenceRunner):
diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py
index 7eb784148..2742032dd 100644
--- a/fast_llm/models/multimodal/model.py
+++ b/fast_llm/models/multimodal/model.py
@@ -151,9 +151,16 @@ def preprocess_batch(
         iteration: int,
         metrics: dict | None = None,
         extra_kwargs: dict[str, typing.Any] | None = None,
+        device: torch.device | None,
     ) -> list[tuple[torch.Tensor, dict]]:
         preprocessed = super().preprocess_batch(
-            batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics, extra_kwargs=extra_kwargs
+            batch,
+            preprocessed_meta,
+            phase=phase,
+            iteration=iteration,
+            metrics=metrics,
+            extra_kwargs=extra_kwargs,
+            device=device,
         )
         # TODO: Support micro-sequences.
         assert len(preprocessed) == 1, "Micro-sequences not supported for MultiModalModel."
diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py
new file mode 100644
index 000000000..e8fd2f384
--- /dev/null
+++ b/tests/data/test_preprocessing.py
@@ -0,0 +1,65 @@
+import pytest
+import torch
+
+from fast_llm.config import NoAutoValidate
+from fast_llm.data.batch.config import LanguageModelBatchPreprocessingConfig
+from fast_llm.data.batch.language_model import LanguageModelPreprocessedBatch
+from fast_llm.data.document.language_model import LanguageModelDocument
+from fast_llm.data.document.range import RangeDocument
+from fast_llm.data.document.token import TokenDocument
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.models.gpt.config import GPTBatchConfig
+from fast_llm.utils import Assert
+
+
+# TODO: Test padding, more scenarios
+# TODO: Check rest of preprocessing output
+@pytest.mark.parametrize(
+    ("tokens", "loss_masking_spans"),
+    (
+        ([[100, 101, 102, 103, 104, 105, 106, 107]], [None]),  # Simple case
+        ([[100, 101, -100, -100, 104, 105, 106, 107]], [None]),  # Negative tokens
+        ([[100, 101, 102, 103, 104, 105, 106, 107]], [[(3, 5)]]),  # Loss masking span
+        ([[100, 101, 102, 103, -100, -100, 106, 107]], [[(2, 3)]]),  # Both
+        (
+            [
+                [100, 101, -100, 103, -100, -100, 106, 107],
+                [100, 101, 102, 103, 104, 105, 106, 107],
+            ],
+            [[(2, 3)], None],
+        ),  # Two samples
+    ),
+)
+def test_preprocessing(tokens, loss_masking_spans):
+    documents = [
+        LanguageModelDocument(
+            tokens=TokenDocument(tokens=torch.tensor(tokens_, dtype=torch.int64)),
+            loss_masking_spans=None if loss_masking_spans_ is None else RangeDocument(ranges=loss_masking_spans_),
+        )
+        for tokens_, loss_masking_spans_ in zip(tokens, loss_masking_spans, strict=True)
+    ]
+    with NoAutoValidate():
+        batch_config = GPTBatchConfig(sequence_length=sum(len(document) for document in documents) - 1)
+    batch_config.setup(DistributedConfig())
+    batch_config.validate()
+    config = LanguageModelBatchPreprocessingConfig(batch=batch_config)
+    preprocessed = LanguageModelPreprocessedBatch.from_documents(documents, config)
+
+    Assert.eq(len(preprocessed.micro_batches), 1)
+    micro_batch = preprocessed.micro_batches[0]
+
+    Assert.all_equal(micro_batch.tokens, torch.cat([document.tokens.tokens for document in documents])[:-1])
+
+    label_tokens = []
+    for document in documents:
+        label_tokens_ = document.tokens.tokens.clone()
+        # Mask cross-document attention
+        label_tokens_[0] = -100
+        # Loss masking spans
+        if document.loss_masking_spans is not None:
+            for begin, end in document.loss_masking_spans.ranges:
+                label_tokens_[begin:end] = -100
+        label_tokens.append(label_tokens_)
+
+    Assert.eq(len(micro_batch.labels), 1)
+    Assert.all_equal(micro_batch.labels[0], torch.cat(label_tokens)[1:])
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index c14232b4f..fe6128b6a 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -13,8 +13,7 @@
 from fast_llm.utils import Assert
 from tests.utils.utils import get_base_model, get_stage

-SEQUENCE_LENGTH = 200
-BATCH_SIZE = 4
+NUM_TOKENS = 200
 HIDDEN_SIZE = 256
 VOCAB_SIZE = 500

@@ -80,27 +79,35 @@ def get_config(self) -> GPTModelConfig:
     def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         input_ = torch.randn(
-            (BATCH_SIZE * SEQUENCE_LENGTH, HIDDEN_SIZE),
+            (NUM_TOKENS, HIDDEN_SIZE),
             dtype=(torch.float32 if self.full_precision_residual else self.compute_dtype.torch),
             device=device,
             requires_grad=True,
         )
-        label_shape = (BATCH_SIZE * (SEQUENCE_LENGTH + self.prediction_heads - 1),)
         kwargs: dict[str, typing.Any] = {
             AttentionKwargs.grad_output: 1.0,
         }
         if self.loss_masking:
-            kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device)
+            kwargs[LanguageModelKwargs.loss_mask] = [
+                torch.randint(0, 2, (NUM_TOKENS,), dtype=torch.bool, device=device)
+                for _ in range(self.prediction_heads)
+            ]
         if self.actual_label_loss is not False:
-            labels = torch.randint(
-                0,
-                VOCAB_SIZE,
-                label_shape,
-                dtype=torch.int64,
-                device=device,
-            )
+            labels = [
+                torch.randint(
+                    0,
+                    VOCAB_SIZE,
+                    (NUM_TOKENS,),
+                    dtype=torch.int64,
+                    device=device,
+                )
+                for _ in range(self.prediction_heads)
+            ]
             if LanguageModelKwargs.loss_mask in kwargs:
-                labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], labels, -100)
+                labels = [
+                    torch.where(mask, labels_, -100)
+                    for labels_, mask in zip(labels, kwargs[LanguageModelKwargs.loss_mask], strict=True)
+                ]
             kwargs[LanguageModelKwargs.labels] = labels

         if self.distillation_loss is not False:
@@ -138,13 +145,7 @@ def get_reference_outputs(
         losses = {}

         if self.actual_label_loss is not False:
-            labels = (
-                kwargs[LanguageModelKwargs.labels]
-                .view(BATCH_SIZE, (SEQUENCE_LENGTH + self.prediction_heads - 1))[
-                    :, head._prediction_distance : head._prediction_distance + SEQUENCE_LENGTH
-                ]
-                .flatten()
-            )
+            labels = kwargs[LanguageModelKwargs.labels][head._prediction_distance - 1]
             label_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none").mean()
             losses["label"] = label_loss.detach()
             total_loss = total_loss + float(self.actual_label_loss) * label_loss
@@ -156,7 +157,9 @@ def get_reference_outputs(
                 reduction="none",
             )
             if LanguageModelKwargs.loss_mask in kwargs:
-                distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask]
+                distillation_loss = (
+                    distillation_loss * kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1]
+                )
             distillation_loss = distillation_loss.mean()
             losses["distillation"] = distillation_loss.detach()
             total_loss = total_loss + float(self.distillation_loss) * distillation_loss
@@ -164,7 +167,7 @@ def get_reference_outputs(
         if self.z_loss is not False:
             z_loss = torch.logsumexp(logits, dim=-1) ** 2
             if LanguageModelKwargs.loss_mask in kwargs:
-                z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask]
+                z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1]
             z_loss = z_loss.mean()
             losses["z_loss"] = z_loss.detach()
             total_loss = total_loss + float(self.z_loss) * z_loss
@@ -176,7 +179,7 @@ def get_reference_outputs(
         else:
             losses = {LM_HEAD_LOSS_NAME: total_loss.detach()}

-        if head._prediction_distance > 0:
+        if head._prediction_distance > 1:
             losses = {f"{name}_{head._prediction_distance}": loss for name, loss in losses.items()}

         return total_loss.detach(), input_.grad, logit_weight.grad, normalization_weight.grad, losses
@@ -236,12 +239,12 @@ def test_lm_head(test_config: LMHeadTestConfig):
         else None
     )

-    for prediction_distance in range(model_config.base_model.head.prediction_heads):
+    for prediction_distance in range(1, model_config.base_model.head.prediction_heads + 1):
         # Prepare the LM head
-        head = model.head if prediction_distance == 0 else model.multi_token_prediction.heads[prediction_distance - 1]
+        head = model.head if prediction_distance == 1 else model.multi_token_prediction.heads[prediction_distance - 2]
         Assert.custom(isinstance, head, LanguageModelHead)
         Assert.eq(head._prediction_distance, prediction_distance)
-        is_duplicate = test_config.tied_embedding_weight or prediction_distance > 0
+        is_duplicate = test_config.tied_embedding_weight or prediction_distance > 1
         stage = get_stage(
             [head],
             distributed,
@@ -255,7 +258,7 @@ def test_lm_head(test_config: LMHeadTestConfig):

         ref_total_loss, ref_input_grad, ref_logit_weight_grad, ref_normalization_weight_grad, ref_losses = (
             test_config.get_reference_outputs(
-                head, input_, kwargs, tied_logit_weight if prediction_distance > 0 else None
+                head, input_, kwargs, tied_logit_weight if prediction_distance > 1 else None
             )
         )
diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py
deleted file mode 100644
index f0af94256..000000000
--- a/tests/test_loss_mask.py
+++ /dev/null
@@ -1,254 +0,0 @@
-"""
-Integration test that loss_mask correctly combines all masking sources:
-- Negative labels (padding and image placeholders)
-- loss_masking_spans
-
-Tests the actual preprocess_batch code path in fast_llm/models/gpt/model.py
-"""
-
-import torch
-
-from fast_llm.config import NoAutoValidate
-from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.layers.language_model.config import LanguageModelKwargs
-from fast_llm.models.gpt.config import GPTBatchConfig, GPTModelConfig
-from tests.utils.utils import get_base_model
-
-
-def create_test_batch(
-    tokens: torch.Tensor,
-    lengths: list[list[int]] | None = None,
-    loss_masking_spans: list[list[tuple[int, int]]] | None = None,
-) -> LanguageModelBatch:
-    """Create a LanguageModelBatch for testing."""
-    token_batch = TokenBatch(tokens, lengths)
-
-    if loss_masking_spans is not None:
-        range_batch = RangeBatch(loss_masking_spans, sample_size=tokens.shape[1])
-    else:
-        range_batch = None
-
-    return LanguageModelBatch(
-        tokens=token_batch,
-        loss_masking_spans=range_batch,
-    )
-
-
-def get_minimal_model():
-    """Create a minimal GPT model for testing."""
-    config = GPTModelConfig.from_dict(
-        {
-            "base_model": {
-                "decoder": {"num_blocks": 1},
-                "embeddings": {"vocab_size": 1000},
-                "hidden_size": 64,
-            },
-            "distributed": {"use_cuda": torch.cuda.is_available()},
-        },
-    )
-    model, distributed = get_base_model(config)
-    return model, distributed
-
-
-def run_preprocess_batch(model, distributed_config, batch: LanguageModelBatch, phase: PhaseType = PhaseType.training):
-    """
-    Run preprocess_batch with proper GPTBatchConfig metadata.
-
-    This avoids the code path that accesses prediction_heads directly.
-    """
-    micro_batch_size, sequence_length = batch.tokens.tokens.shape
-
-    # Create GPTBatchConfig for metadata with proper setup
-    with NoAutoValidate():
-        batch_config = GPTBatchConfig(
-            batch_size=micro_batch_size,
-            sequence_length=sequence_length,
-        )
-    batch_config.setup(distributed_config)
-    batch_config.validate()
-
-    # Get preprocessed metadata using GPTBatchConfig
-    preprocessed_meta = model.preprocess_meta(batch_config, phase)
-
-    # Run preprocess_batch with the actual batch data
-    return model.preprocess_batch(
-        batch,
-        preprocessed_meta=preprocessed_meta,
-        phase=phase,
-        iteration=0,
-    )
-
-
-class TestLossMaskIntegration:
-    """
-    Integration tests for loss_mask computation in preprocess_batch.
-
-    These tests verify the masking behavior by checking labels, since:
-    1. loss_mask = labels >= 0 (masks negative labels)
-    2. loss_masking_spans positions are also masked
-    3. labels are set to -100 at all masked positions
-
-    So if labels are -100 at expected positions, the masking is working.
-    """
-
-    def test_negative_labels_preserved(self):
-        """Test that negative input tokens result in negative labels (shifted by 1)."""
-        model, distributed = get_minimal_model()
-
-        # Sequence: [text, text, IMG(-100), IMG(-100), text, text, text, text]
-        # Labels (shifted by 1): [text, IMG, IMG, text, text, text, text, ?]
-        tokens = torch.tensor(
-            [
-                [100, 101, -100, -100, 104, 105, 106, 107],
-            ],
-            dtype=torch.int64,
-        )
-
-        batch = create_test_batch(tokens)
-        preprocessed = run_preprocess_batch(model, distributed.config, batch)
-
-        assert len(preprocessed) == 1
-        _, kwargs = preprocessed[0]
-
-        labels = kwargs[LanguageModelKwargs.labels]
-        # Flatten for easier indexing (handles sequence_first)
-        labels_flat = labels.flatten()
-
-        # Labels at positions 1,2 should be -100 (the next token after positions 0,1 is -100)
-        assert labels_flat[1].item() == -100, f"Label at position 1 should be -100, got {labels_flat[1].item()}"
-        assert labels_flat[2].item() == -100, f"Label at position 2 should be -100, got {labels_flat[2].item()}"
-
-        # Labels at other positions should be positive
-        assert labels_flat[0].item() > 0, "Label at position 0 should be positive"
-        assert labels_flat[3].item() > 0, "Label at position 3 should be positive"
-
-    def test_loss_masking_spans_set_labels_to_negative(self):
-        """Test that loss_masking_spans positions have labels set to -100."""
-        model, distributed = get_minimal_model()
-
-        # All positive tokens
-        tokens = torch.tensor(
-            [
-                [100, 101, 102, 103, 104, 105, 106, 107],
-            ],
-            dtype=torch.int64,
-        )
-
-        # loss_masking_spans are in TOKEN space, but labels are shifted by 1
-        # Span (3, 5) in token space -> after cropping with labels_begin=1 -> (2, 4) in label space
-        # This will mask label positions 2 and 3
-        loss_masking_spans = [[(3, 5)]]
-
-        batch = create_test_batch(tokens, loss_masking_spans=loss_masking_spans)
-        preprocessed = run_preprocess_batch(model, distributed.config, batch)
-
-        assert len(preprocessed) == 1
-        _, kwargs = preprocessed[0]
-
-        labels = kwargs[LanguageModelKwargs.labels]
-        labels_flat = labels.flatten()
-
-        # After cropping, positions 2,3 in label space should be masked (set to -100)
-        assert labels_flat[2].item() == -100, f"Label at position 2 should be -100, got {labels_flat[2].item()}"
-        assert labels_flat[3].item() == -100, f"Label at position 3 should be -100, got {labels_flat[3].item()}"
-
-        # Positions outside the span should be positive
-        assert labels_flat[0].item() > 0, "Label at position 0 should be positive"
-        assert labels_flat[1].item() > 0, "Label at position 1 should be positive"
-        assert labels_flat[4].item() > 0, "Label at position 4 should be positive"
-
-    def test_combined_masking_negative_labels_and_spans(self):
-        """Test that both negative labels AND loss_masking_spans result in -100 labels."""
-        model, distributed = get_minimal_model()
-
-        # Tokens with -100 at positions 4,5 (will affect labels at 3,4)
-        tokens = torch.tensor(
-            [
-                [100, 101, 102, 103, -100, -100, 106, 107],
-            ],
-            dtype=torch.int64,
-        )
-
-        # loss_masking_spans in token space: (2, 3) -> after cropping to label space: (1, 2)
-        # This will mask label position 1
-        loss_masking_spans = [[(2, 3)]]
-
-        batch = create_test_batch(tokens, loss_masking_spans=loss_masking_spans)
-        preprocessed = run_preprocess_batch(model, distributed.config, batch)
-
-        assert len(preprocessed) == 1
-        _, kwargs = preprocessed[0]
-
-        labels = kwargs[LanguageModelKwargs.labels]
-        labels_flat = labels.flatten()
-
-        # Position 1 should be -100 (from loss_masking_spans after cropping)
-        assert labels_flat[1].item() == -100, f"Position 1 should be -100 (from spans), got {labels_flat[1].item()}"
-
-        # Positions 3,4 should be -100 (from negative input tokens at positions 4,5)
-        assert labels_flat[3].item() == -100, f"Position 3 should be -100 (from IMG), got {labels_flat[3].item()}"
-        assert labels_flat[4].item() == -100, f"Position 4 should be -100 (from IMG), got {labels_flat[4].item()}"
-
-        # Position 0, 2, 5 should be positive (not masked)
-        assert labels_flat[0].item() > 0, "Position 0 should be positive"
-        assert labels_flat[2].item() > 0, "Position 2 should be positive"
-        assert labels_flat[5].item() > 0, "Position 5 should be positive"
-
-    def test_all_padding_sample(self):
-        """Test that a sample with all -100 tokens (padding) results in all -100 labels."""
-        model, distributed = get_minimal_model()
-
-        # Sample 0: normal tokens
-        # Sample 1: all padding (-100)
-        tokens = torch.tensor(
-            [
-                [100, 101, 102, 103, 104, 105, 106, 107],
-                [-100, -100, -100, -100, -100, -100, -100, -100],
-            ],
-            dtype=torch.int64,
-        )
-
-        batch = create_test_batch(tokens)
-        preprocessed = run_preprocess_batch(model, distributed.config, batch)
-
-        assert len(preprocessed) == 1
-        _, kwargs = preprocessed[0]
-
-        labels = kwargs[LanguageModelKwargs.labels]
-
-        # Get labels for sample 1 (all should be -100)
-        sample1_labels = labels[8:]
-
-        assert torch.all(sample1_labels == -100), f"All labels in padding sample should be -100, got {sample1_labels}"
-
-    def test_image_placeholders_interleaved(self):
-        """Test realistic scenario: text, image placeholders, text interleaved."""
-        model, distributed = get_minimal_model()
-
-        # Realistic sequence: [BOS, text, IMG, IMG, IMG, text, text, EOS]
-        # Labels should be: [text, IMG(-100), IMG(-100), IMG(-100), text, text, EOS, ?]
-        tokens = torch.tensor(
-            [
-                [1, 100, -100, -100, -100, 200, 201, 2],
-            ],
-            dtype=torch.int64,
-        )
-
-        batch = create_test_batch(tokens)
-        preprocessed = run_preprocess_batch(model, distributed.config, batch)
-
-        assert len(preprocessed) == 1
-        _, kwargs = preprocessed[0]
-
-        labels = kwargs[LanguageModelKwargs.labels]
-        labels_flat = labels.flatten()
-
-        # Labels at positions 1,2,3 should be -100 (next tokens are IMG)
-        assert labels_flat[1].item() == -100, f"Position 1 should be -100, got {labels_flat[1].item()}"
-        assert labels_flat[2].item() == -100, f"Position 2 should be -100, got {labels_flat[2].item()}"
-        assert labels_flat[3].item() == -100, f"Position 3 should be -100, got {labels_flat[3].item()}"
-
-        # Labels at positions 0, 4, 5 should be positive
-        assert labels_flat[0].item() > 0, f"Position 0 should be positive, got {labels_flat[0].item()}"
-        assert labels_flat[4].item() > 0, f"Position 4 should be positive, got {labels_flat[4].item()}"
-        assert labels_flat[5].item() > 0, f"Position 5 should be positive, got {labels_flat[5].item()}"
diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py
index bd5a92720..7c17a107b 100644
--- a/tests/utils/distributed_configs.py
+++ b/tests/utils/distributed_configs.py
@@ -112,7 +112,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
         name="bf16",
         compare="simple",
         # Also tests parallel data loader.
-        config_args=["model.distributed.compute_dtype=bf16", "training.num_workers=2"],
+        config_args=["model.distributed.compute_dtype=bf16", "training.num_workers=1"],
         num_gpus=1,
         compare_config=_bf16_compare,
     ),