6 changes: 6 additions & 0 deletions fast_llm/data/auto.py
@@ -14,4 +14,10 @@
GPTFimSampledDatasetConfig,
GPTRandomDatasetConfig,
)
from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig # isort: skip
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip
from fast_llm.data.sample.abstract import NullReaderConfig # isort: skip
from fast_llm.data.sample.language_model import LanguageModelReaderConfig # isort: skip
from fast_llm.data.sample.patch import PatchReaderConfig # isort: skip
from fast_llm.data.sample.range import RangeReaderConfig # isort: skip
from fast_llm.data.sample.token import TokenReaderConfig # isort: skip
25 changes: 16 additions & 9 deletions fast_llm/data/dataset/memmap.py
@@ -18,6 +18,21 @@ class MemmapDataset[SampleType: Sample](IndexedDataset[SampleType]):
A memory map dataset, which handles lazy loading of a pre-processed dataset.
"""

    @staticmethod
    def read_reader_config(path: pathlib.Path | str) -> MemmapIndexDatasetReaderConfig:
        """
        Read the MemmapIndexDatasetReaderConfig from a memmap file.
        """
        path = pathlib.Path(path) if isinstance(path, str) else path
        with path.open("rb") as stream:
            # Verify file type.
            assert stream.read(len(FILE_HEADER)) == FILE_HEADER
            # Go to reader configs.
            stream.seek(int.from_bytes(stream.read(8), signed=False))
            # Read the reader config.
            config_bytes = stream.read(int.from_bytes(stream.read(4), signed=False))
            return MemmapIndexDatasetReaderConfig.from_dict(json.loads(config_bytes.decode("utf-8")))

    def __init__(
        self,
        name: str,
@@ -32,15 +47,7 @@ def _init(self, name: str, path: pathlib.Path | str, preprocessing: Preprocessin
        self._path = path
        self._preprocessing = preprocessing

        with self._path.open("rb") as stream:
            # Very file type.
            assert stream.read(len(FILE_HEADER)) == FILE_HEADER
            # Go to reader configs.
            stream.seek(int.from_bytes(stream.read(8), signed=False))
            # Read the reader config.
            reader_config = MemmapIndexDatasetReaderConfig.from_dict(
                json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8"))
            )
        reader_config = self.read_reader_config(self._path)

        self._memmap = np.memmap(self._path, mode="r")
        self._reader = reader_config.get_reader(memoryview(self._memmap), self._preprocessing)
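As a usage sketch of the extracted helper (the import path comes from this file; the dataset path is taken from the README example below and is an assumption), the reader config can now be inspected without constructing a `MemmapDataset`:

```python
from fast_llm.data.dataset.memmap import MemmapDataset

# Read only the reader config stored in the file; the data itself is not memory-mapped here.
reader_config = MemmapDataset.read_reader_config("datasets/domain_b/shard_0.fast_llm_dataset")
print(reader_config)
```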
94 changes: 94 additions & 0 deletions fast_llm/data/preparator/dataset_discovery/README.md
@@ -0,0 +1,94 @@
# Dataset Discovery

Automatically discover `.fast_llm_dataset` files and generate a blended config with token-proportional weights.

## Quick Start

Using the tools wrapper:
```bash
python tools/discover_datasets.py <directory> -o <output.yaml>
```

Using Fast-LLM CLI with config file:
```yaml
type: prepare_dataset_discovery
directory: /path/to/datasets
output: blended_dataset.yaml
ignore_paths: [test_data, checkpoints] # Optional
```

```bash
python -m fast_llm.cli --config config.yaml
```

## What It Does

1. Scans directory tree for `.fast_llm_dataset` files
2. Reads token counts from binary file headers
3. Generates a hierarchical blended config with automatic weights (see the sketch after this list)
4. Preserves directory structure
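
The weighting step can be illustrated with a minimal sketch (not the preparator's actual implementation; the helper name `build_blended_config` and the `{path: token_count}` mapping are assumed for illustration):

```python
import pathlib


def build_blended_config(name: str, token_counts: dict[pathlib.Path, int]) -> dict:
    """Build a blended dataset config with weights proportional to token counts."""
    # Skip empty files, mirroring the "0 tokens" warning behaviour described below.
    entries = sorted((path, tokens) for path, tokens in token_counts.items() if tokens > 0)
    if not entries:
        raise ValueError("No non-empty datasets found")
    if len(entries) == 1:
        # A single dataset is returned directly, not wrapped in a blended config.
        return {"type": "memmap", "path": str(entries[0][0])}
    return {
        "type": "blended",
        "name": name,
        "datasets": [{"type": "memmap", "path": str(path)} for path, _ in entries],
        # Weights are token counts expressed in billions, as in the example below.
        "weights": [tokens / 1e9 for _, tokens in entries],
    }
```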

## Example

Input directory structure:
```
datasets/
β”œβ”€β”€ domain_a/
β”‚   β”œβ”€β”€ shard_0.fast_llm_dataset (1B tokens)
β”‚   └── shard_1.fast_llm_dataset (1B tokens)
└── domain_b/
    └── shard_0.fast_llm_dataset (4B tokens)
```

Generated config (`blended.yaml`):
```yaml
type: blended
name: datasets
datasets:
  - type: blended
    name: domain_a
    datasets:
      - type: memmap
        path: datasets/domain_a/shard_0.fast_llm_dataset
      - type: memmap
        path: datasets/domain_a/shard_1.fast_llm_dataset
    weights: [1.0, 1.0]
  - type: memmap
    path: datasets/domain_b/shard_0.fast_llm_dataset
weights: [2.0, 4.0]  # In billions
```

Here `domain_a` contributes 1B + 1B = 2B tokens and `domain_b` contributes 4B, so the top-level weights are `[2.0, 4.0]`; the inner weights `[1.0, 1.0]` are the per-shard token counts, all in billions.

Use in training:
```yaml
data:
  datasets:
    training:
      type: file
      path: blended.yaml
```

## Options

- **directory**: Root directory to scan (required)
- **output**: Output YAML file path (required)
- **ignore_paths**: Paths to exclude, relative or absolute (optional)

## Key Features

- **Token-proportional sampling**: Datasets sampled in proportion to token count (larger datasets are sampled more often)
- **Hierarchical grouping**: Directory structure preserved in config
- **Automatic weights**: Calculated from binary file metadata
- **Error handling**: Skips unreadable files with warnings

## Notes

- A single dataset is returned directly (not wrapped in a blended config)
- Files with 0 tokens are skipped with a warning
- Empty directories raise an error
- Datasets are sorted alphabetically

## Testing

```bash
pytest tests/data/test_dataset_discovery.py
```
4 changes: 4 additions & 0 deletions fast_llm/data/preparator/dataset_discovery/__init__.py
@@ -0,0 +1,4 @@
from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig
from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator

__all__ = ["DatasetDiscoveryConfig", "DatasetDiscoveryPreparator"]
46 changes: 46 additions & 0 deletions fast_llm/data/preparator/dataset_discovery/config.py
@@ -0,0 +1,46 @@
import pathlib
import typing

from fast_llm.config import Field, FieldHint, config_class
from fast_llm.data.preparator.config import DatasetPreparatorConfig
from fast_llm.engine.config_utils.runnable import RunnableConfig

if typing.TYPE_CHECKING:
    from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator


@config_class(dynamic_type={RunnableConfig: "prepare_dataset_discovery", DatasetPreparatorConfig: "dataset_discovery"})
class DatasetDiscoveryConfig(DatasetPreparatorConfig):
    """
    Configuration for the dataset discovery preparator.

    This preparator recursively discovers .fast_llm_dataset files in a directory
    and generates a blended dataset config with weights proportional to token counts.
    """

    directory: pathlib.Path = Field(
        desc="Directory to search for datasets recursively",
        hint=FieldHint.core,
    )
    output: pathlib.Path = Field(
        desc="Output path for the generated config YAML file",
        hint=FieldHint.core,
    )
    ignore_paths: list[pathlib.Path] = Field(
        default_factory=list,
        desc="List of paths to ignore during dataset discovery (can be absolute or relative to directory)",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        super()._validate()
        if not self.directory.exists():
            raise ValueError(f"Directory does not exist: {self.directory}")
        if not self.directory.is_dir():
            raise ValueError(f"Path is not a directory: {self.directory}")

    @classmethod
    def get_dataset_preparator_class(cls) -> type["DatasetDiscoveryPreparator"]:
        from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator

        return DatasetDiscoveryPreparator