diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index 22ab3d731..d39ce1e4a 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -14,4 +14,10 @@ GPTFimSampledDatasetConfig, GPTRandomDatasetConfig, ) +from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig # isort: skip from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip +from fast_llm.data.sample.abstract import NullReaderConfig # isort: skip +from fast_llm.data.sample.language_model import LanguageModelReaderConfig # isort: skip +from fast_llm.data.sample.patch import PatchReaderConfig # isort: skip +from fast_llm.data.sample.range import RangeReaderConfig # isort: skip +from fast_llm.data.sample.token import TokenReaderConfig # isort: skip diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index f80a48b0a..4b62d9d8a 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -18,6 +18,21 @@ class MemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): A memory map dataset, which handles lazy loading of a pre-processed dataset. """ + @staticmethod + def read_reader_config(path: pathlib.Path | str) -> MemmapIndexDatasetReaderConfig: + """ + Read the MemmapIndexDatasetReaderConfig from a memmap file. + """ + path = pathlib.Path(path) if isinstance(path, str) else path + with path.open("rb") as stream: + # Verify file type. + assert stream.read(len(FILE_HEADER)) == FILE_HEADER + # Go to reader configs. + stream.seek(int.from_bytes(stream.read(8), signed=False)) + # Read the reader config. + config_bytes = stream.read(int.from_bytes(stream.read(4), signed=False)) + return MemmapIndexDatasetReaderConfig.from_dict(json.loads(config_bytes.decode("utf-8"))) + def __init__( self, name: str, @@ -32,15 +47,7 @@ def _init(self, name: str, path: pathlib.Path | str, preprocessing: Preprocessin self._path = path self._preprocessing = preprocessing - with self._path.open("rb") as stream: - # Very file type. - assert stream.read(len(FILE_HEADER)) == FILE_HEADER - # Go to reader configs. - stream.seek(int.from_bytes(stream.read(8), signed=False)) - # Read the reader config. - reader_config = MemmapIndexDatasetReaderConfig.from_dict( - json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) - ) + reader_config = self.read_reader_config(self._path) self._memmap = np.memmap(self._path, mode="r") self._reader = reader_config.get_reader(memoryview(self._memmap), self._preprocessing) diff --git a/fast_llm/data/preparator/dataset_discovery/README.md b/fast_llm/data/preparator/dataset_discovery/README.md new file mode 100644 index 000000000..b88347f0d --- /dev/null +++ b/fast_llm/data/preparator/dataset_discovery/README.md @@ -0,0 +1,94 @@ +# Dataset Discovery + +Automatically discover `.fast_llm_dataset` files and generate a blended config with token-proportional weights. + +## Quick Start + +Using the tools wrapper: +```bash +python tools/discover_datasets.py -o +``` + +Using Fast-LLM CLI with config file: +```yaml +type: prepare_dataset_discovery +directory: /path/to/datasets +output: blended_dataset.yaml +ignore_paths: [test_data, checkpoints] # Optional +``` + +```bash +python -m fast_llm.cli --config config.yaml +``` + +## What It Does + +1. Scans directory tree for `.fast_llm_dataset` files +2. Reads token counts from binary file headers +3. Generates hierarchical blended config with automatic weights +4. Preserves directory structure + +## Example + +Input directory structure: +``` +datasets/ +├── domain_a/ +│ ├── shard_0.fast_llm_dataset (1B tokens) +│ └── shard_1.fast_llm_dataset (1B tokens) +└── domain_b/ + └── shard_0.fast_llm_dataset (4B tokens) +``` + +Generated config (`blended.yaml`): +```yaml +type: blended +name: datasets +datasets: + - type: blended + name: domain_a + datasets: + - type: memmap + path: datasets/domain_a/shard_0.fast_llm_dataset + - type: memmap + path: datasets/domain_a/shard_1.fast_llm_dataset + weights: [1.0, 1.0] + - type: memmap + path: datasets/domain_b/shard_0.fast_llm_dataset +weights: [2.0, 4.0] # In billions +``` + +Use in training: +```yaml +data: + datasets: + training: + type: file + path: blended.yaml +``` + +## Options + +- **directory**: Root directory to scan (required) +- **output**: Output YAML file path (required) +- **ignore_paths**: Paths to exclude, relative or absolute (optional) + +## Key Features + +- **Token-proportional sampling**: Datasets sampled by token count (larger datasets sampled more) +- **Hierarchical grouping**: Directory structure preserved in config +- **Automatic weights**: Calculated from binary file metadata +- **Error handling**: Skips unreadable files with warnings + +## Notes + +- Single datasets returned directly (not wrapped) +- Files with 0 tokens skipped with warning +- Empty directories raise error +- Datasets sorted alphabetically + +## Testing + +```bash +pytest tests/data/test_dataset_discovery.py +``` diff --git a/fast_llm/data/preparator/dataset_discovery/__init__.py b/fast_llm/data/preparator/dataset_discovery/__init__.py new file mode 100644 index 000000000..a9d38880a --- /dev/null +++ b/fast_llm/data/preparator/dataset_discovery/__init__.py @@ -0,0 +1,4 @@ +from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig +from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator + +__all__ = ["DatasetDiscoveryConfig", "DatasetDiscoveryPreparator"] diff --git a/fast_llm/data/preparator/dataset_discovery/config.py b/fast_llm/data/preparator/dataset_discovery/config.py new file mode 100644 index 000000000..d14b5bfd8 --- /dev/null +++ b/fast_llm/data/preparator/dataset_discovery/config.py @@ -0,0 +1,46 @@ +import pathlib +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.data.preparator.config import DatasetPreparatorConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig + +if typing.TYPE_CHECKING: + from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator + + +@config_class(dynamic_type={RunnableConfig: "prepare_dataset_discovery", DatasetPreparatorConfig: "dataset_discovery"}) +class DatasetDiscoveryConfig(DatasetPreparatorConfig): + """ + Configuration for the dataset discovery preparator. + + This preparator recursively discovers .fast_llm_dataset files in a directory + and generates a blended dataset config with weights proportional to token counts. + """ + + directory: pathlib.Path = Field( + desc="Directory to search for datasets recursively", + hint=FieldHint.core, + ) + output: pathlib.Path = Field( + desc="Output path for the generated config YAML file", + hint=FieldHint.core, + ) + ignore_paths: list[pathlib.Path] = Field( + default_factory=list, + desc="List of paths to ignore during dataset discovery (can be absolute or relative to directory)", + hint=FieldHint.optional, + ) + + def _validate(self) -> None: + super()._validate() + if not self.directory.exists(): + raise ValueError(f"Directory does not exist: {self.directory}") + if not self.directory.is_dir(): + raise ValueError(f"Path is not a directory: {self.directory}") + + @classmethod + def get_dataset_preparator_class(cls) -> type["DatasetDiscoveryPreparator"]: + from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator + + return DatasetDiscoveryPreparator diff --git a/fast_llm/data/preparator/dataset_discovery/prepare.py b/fast_llm/data/preparator/dataset_discovery/prepare.py new file mode 100644 index 000000000..25a29ca3e --- /dev/null +++ b/fast_llm/data/preparator/dataset_discovery/prepare.py @@ -0,0 +1,352 @@ +""" +Dataset discovery preparator. + +This module discovers datasets by directly scanning for .fast_llm_dataset files +and reading token counts from their binary headers. +""" + +import logging +import pathlib +from collections import defaultdict + +import yaml + +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preparator.config import DatasetPreparator +from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig + +logger = logging.getLogger(__name__) + + +class DatasetDiscoveryPreparator[ConfigType: DatasetDiscoveryConfig](DatasetPreparator[ConfigType]): + """ + Preparator for discovering datasets by scanning .fast_llm_dataset files. + + Scans a directory tree for .fast_llm_dataset files and reads token counts + from their binary headers to generate a hierarchical blended config. + """ + + _config: DatasetDiscoveryConfig + + def run(self) -> None: + """ + Run the dataset discovery preparator. + """ + # Generate the hierarchical config by finding .fast_llm_dataset files + config = self._create_hierarchical_config( + self._config.directory.resolve(), + ignore_paths=self._config.ignore_paths, + ) + + # Write the config to the output file with header comment + self._config.output.parent.mkdir(parents=True, exist_ok=True) + with open(self._config.output, "w") as f: + # Write header comment + f.write( + "# This file was generated with fast_llm.data.preparator.dataset_discovery; " + "weights are token-counts in billions.\n" + ) + f.write(f"# Configuration:\n") + f.write(f"# directory: {self._config.directory}\n") + if self._config.ignore_paths: + f.write(f"# ignore_paths:\n") + for ignore_path in self._config.ignore_paths: + f.write(f"# - {ignore_path}\n") + f.write("\n") + # Write the YAML config + yaml.safe_dump(config, f, default_flow_style=False, sort_keys=False) + + logger.info(f"Generated dataset config saved to {self._config.output}") + + # Print a preview of the config + logger.info("\nGenerated config preview:") + preview = yaml.safe_dump(config, default_flow_style=False, sort_keys=False) + for line in preview.split("\n")[:50]: + logger.info(line) + + if len(preview.split("\n")) > 50: + logger.info("... (truncated)") + + @staticmethod + def _is_subpath(path: pathlib.Path, parent: pathlib.Path) -> bool: + """Check if path is under parent directory.""" + try: + path.relative_to(parent) + return True + except ValueError: + return False + + def _find_dataset_files( + self, root_dir: pathlib.Path, ignore_paths: list[pathlib.Path] | None = None + ) -> list[pathlib.Path]: + """ + Recursively find all .fast_llm_dataset files in the directory tree. + + Args: + root_dir: Root directory to search + ignore_paths: List of paths to ignore (can be absolute or relative to root_dir) + + Returns: + List of paths to .fast_llm_dataset files + """ + # Normalize ignore paths to absolute paths + ignore_paths_absolute = set() + if ignore_paths: + for ignore_path in ignore_paths: + if ignore_path.is_absolute(): + ignore_paths_absolute.add(ignore_path.resolve()) + else: + ignore_paths_absolute.add((root_dir / ignore_path).resolve()) + + # Find all .fast_llm_dataset files and filter out ignored ones + dataset_files = [] + for dataset_file in root_dir.rglob("*.fast_llm_dataset"): + dataset_file_resolved = dataset_file.resolve() + + # Check if this file is under any ignored path + is_ignored = any( + self._is_subpath(dataset_file_resolved, ignore_path) for ignore_path in ignore_paths_absolute + ) + + if not is_ignored: + dataset_files.append(dataset_file) + + # Sort by path for consistent ordering + return sorted(dataset_files) + + @staticmethod + def _read_memmap_num_tokens(memmap_path: pathlib.Path) -> int: + """Read number of tokens from a .fast_llm_dataset memmap file.""" + + if not memmap_path.exists(): + logger.warning(f"Memmap file not found: {memmap_path}") + return 0 + + try: + reader_config = MemmapDataset.read_reader_config(memmap_path) + return reader_config.num_tokens + except Exception as e: + logger.warning(f"Failed to read memmap file {memmap_path}: {e}") + return 0 + + def _get_token_count(self, dataset_file: pathlib.Path) -> float | None: + """ + Get token count in billions for a .fast_llm_dataset file. + + Returns: + Token count in billions, or None if the file couldn't be read + """ + num_tokens = self._read_memmap_num_tokens(dataset_file) + if num_tokens == 0: + logger.warning(f" - {dataset_file.name}: skipping (0 tokens or read error)") + return None + logger.debug(f" - {dataset_file.name}: {num_tokens:,} tokens") + return num_tokens / 1e9 + + def _create_memmap_config_for_dataset(self, dataset_file: pathlib.Path) -> dict: + """ + Create a memmap config dictionary for a .fast_llm_dataset file. + + Args: + dataset_file: Path to the .fast_llm_dataset file + + Returns: + Dictionary representing a memmap dataset config + """ + return {"type": "memmap", "path": str(dataset_file)} + + @staticmethod + def _get_directory_name(directory: pathlib.Path, root_dir: pathlib.Path, suffix: str = "") -> str: + """ + Generate a name for a directory relative to root. + + Args: + directory: The directory to name + root_dir: The root directory + suffix: Optional suffix to append to the name + + Returns: + A string name for the directory + """ + rel_path = directory.relative_to(root_dir) if directory != root_dir else pathlib.Path(".") + base_name = str(rel_path).replace("/", "_").replace(".", root_dir.name) + return f"{base_name}{suffix}" if suffix else base_name + + @staticmethod + def _group_files_by_directory(dataset_files: list[pathlib.Path]) -> dict[pathlib.Path, list[pathlib.Path]]: + """ + Group dataset files by their parent directory. + + Args: + dataset_files: List of dataset file paths + + Returns: + Dictionary mapping directory paths to lists of dataset files in that directory + """ + groups: dict[pathlib.Path, list[pathlib.Path]] = defaultdict(list) + for dataset_file in dataset_files: + groups[dataset_file.parent].append(dataset_file) + + return dict(groups) + + @staticmethod + def _build_directory_tree( + groups: dict[pathlib.Path, list[pathlib.Path]], root_dir: pathlib.Path + ) -> dict[pathlib.Path, set[pathlib.Path]]: + """ + Build a tree structure of directories showing parent-child relationships. + + Args: + groups: Dictionary mapping directories to their dataset files + root_dir: Root directory + + Returns: + Dictionary mapping each directory to its immediate child directories + """ + tree: dict[pathlib.Path, set[pathlib.Path]] = {root_dir: set()} + + for directory in groups.keys(): + # Add all ancestors to the tree + current = directory + while current != root_dir and current.parent != current: + parent = current.parent + if parent not in tree: + tree[parent] = set() + if current not in tree: + tree[current] = set() + tree[parent].add(current) + current = parent + + return tree + + def _create_directory_config( + self, + directory: pathlib.Path, + groups: dict[pathlib.Path, list[pathlib.Path]], + tree: dict[pathlib.Path, set[pathlib.Path]], + root_dir: pathlib.Path, + ) -> tuple[dict, float] | None: + """ + Recursively create a blended config for a directory and its subdirectories. + + Args: + directory: Current directory to process + groups: Dictionary mapping directories to their dataset files + tree: Directory tree structure + root_dir: Root directory + + Returns: + Tuple of (config dictionary, total token count in billions), or None if directory has no datasets + """ + local_datasets = [] + local_tokens = [] + + # Collect dataset files directly in this directory (not in subdirectories) + if directory in groups: + for dataset_file in sorted(groups[directory]): + token_count = self._get_token_count(dataset_file) + if token_count is not None: # Skip files that couldn't be read + local_datasets.append(self._create_memmap_config_for_dataset(dataset_file)) + local_tokens.append(token_count) + + # Recursively process subdirectories + subdir_datasets = [] + subdir_tokens = [] + if directory in tree: + for subdir in sorted(tree[directory]): + subdir_result = self._create_directory_config(subdir, groups, tree, root_dir) + if subdir_result is not None: + subdir_config, subdir_token_count = subdir_result + subdir_datasets.append(subdir_config) + subdir_tokens.append(subdir_token_count) + + # Combine local and subdirectory datasets + if local_datasets and subdir_datasets: + # If multiple local datasets, group them together + if len(local_datasets) > 1: + local_total_tokens = sum(local_tokens) + local_group = { + "type": "blended", + "name": self._get_directory_name(directory, root_dir, "_local"), + "datasets": local_datasets, + "weights": local_tokens, + } + all_datasets = [local_group] + subdir_datasets + all_tokens = [local_total_tokens] + subdir_tokens + else: + all_datasets = local_datasets + subdir_datasets + all_tokens = local_tokens + subdir_tokens + elif local_datasets: + all_datasets = local_datasets + all_tokens = local_tokens + elif subdir_datasets: + all_datasets = subdir_datasets + all_tokens = subdir_tokens + else: + return None + + total_tokens = sum(all_tokens) + + # Don't wrap a single dataset + if len(all_datasets) == 1: + return all_datasets[0], total_tokens + + # Multiple datasets - create blended config + return { + "type": "blended", + "name": self._get_directory_name(directory, root_dir), + "datasets": all_datasets, + "weights": all_tokens, + }, total_tokens + + def _create_hierarchical_config( + self, + root_dir: pathlib.Path, + ignore_paths: list[pathlib.Path] | None = None, + ) -> dict: + """ + Create a hierarchical blended dataset config from all .fast_llm_dataset files in a directory. + + Datasets in the same directory are grouped together with weights proportional to token counts, + and these groups are nested following the directory structure. + + Args: + root_dir: Root directory to search for datasets + ignore_paths: List of paths to ignore (can be absolute or relative to root_dir) + + Returns: + Dictionary representing the hierarchical blended dataset config + """ + logger.info(f"Discovering .fast_llm_dataset files in {root_dir}...") + + if ignore_paths: + logger.info(f"Ignoring {len(ignore_paths)} path(s):") + for ignore_path in ignore_paths: + logger.info(f" - {ignore_path}") + + dataset_files = self._find_dataset_files(root_dir, ignore_paths=ignore_paths) + + if not dataset_files: + raise ValueError(f"No .fast_llm_dataset files found in {root_dir}") + + logger.debug(f"Found {len(dataset_files)} dataset file(s):") + for dataset_file in dataset_files: + logger.debug(f" - {dataset_file.relative_to(root_dir)}") + + # Group dataset files by directory + groups = self._group_files_by_directory(dataset_files) + + # Build directory tree + tree = self._build_directory_tree(groups, root_dir) + + # Create hierarchical config + result = self._create_directory_config(root_dir, groups, tree, root_dir) + + if result is None: + raise ValueError("Failed to create config") + + config, total_tokens = result + + logger.info(f"Total tokens across all datasets: {total_tokens:.2f}B") + + return config diff --git a/tests/data/test_dataset_discovery.py b/tests/data/test_dataset_discovery.py new file mode 100644 index 000000000..dd8eeac46 --- /dev/null +++ b/tests/data/test_dataset_discovery.py @@ -0,0 +1,363 @@ +""" +Tests for the dataset discovery preparator. +""" + +import pathlib + +import pytest + +from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig +from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator + + +class TestDatasetDiscovery: + """Test dataset discovery that scans .fast_llm_dataset files.""" + + def test_find_dataset_files(self, tmp_path: pathlib.Path): + """Test finding .fast_llm_dataset files in directory tree.""" + # Create test directory structure + (tmp_path / "subdir1").mkdir() + (tmp_path / "subdir2").mkdir() + (tmp_path / "subdir1" / "nested").mkdir() + + # Create some .fast_llm_dataset files + (tmp_path / "dataset1.fast_llm_dataset").touch() + (tmp_path / "subdir1" / "dataset2.fast_llm_dataset").touch() + (tmp_path / "subdir1" / "nested" / "dataset3.fast_llm_dataset").touch() + (tmp_path / "subdir2" / "dataset4.fast_llm_dataset").touch() + + # Create some other files that should be ignored + (tmp_path / "readme.txt").touch() + (tmp_path / "subdir1" / "config.yaml").touch() + + # Create config + config = DatasetDiscoveryConfig( + directory=tmp_path, + output=tmp_path / "output.yaml", + ) + + # Create preparator + preparator = DatasetDiscoveryPreparator(config) + + # Find dataset files + dataset_files = preparator._find_dataset_files(tmp_path) + + # Should find all 4 .fast_llm_dataset files + assert len(dataset_files) == 4 + assert all(f.suffix == ".fast_llm_dataset" for f in dataset_files) + + def test_find_dataset_files_with_ignore(self, tmp_path: pathlib.Path): + """Test finding .fast_llm_dataset files with ignore paths.""" + # Create test directory structure + (tmp_path / "keep").mkdir() + (tmp_path / "ignore").mkdir() + + # Create dataset files + (tmp_path / "keep" / "dataset1.fast_llm_dataset").touch() + (tmp_path / "ignore" / "dataset2.fast_llm_dataset").touch() + + # Create config with ignore path + config = DatasetDiscoveryConfig( + directory=tmp_path, + output=tmp_path / "output.yaml", + ignore_paths=[pathlib.Path("ignore")], + ) + + # Create preparator + preparator = DatasetDiscoveryPreparator(config) + + # Find dataset files + dataset_files = preparator._find_dataset_files(tmp_path, ignore_paths=config.ignore_paths) + + # Should find only 1 file (dataset2 should be ignored) + assert len(dataset_files) == 1 + assert dataset_files[0].name == "dataset1.fast_llm_dataset" + + def test_group_files_by_directory(self, tmp_path: pathlib.Path): + """Test grouping dataset files by directory.""" + # Create files + files = [ + tmp_path / "dataset1.fast_llm_dataset", + tmp_path / "dataset2.fast_llm_dataset", + tmp_path / "subdir" / "dataset3.fast_llm_dataset", + ] + + # Group by directory + groups = DatasetDiscoveryPreparator._group_files_by_directory(files) + + # Should have 2 groups + assert len(groups) == 2 + assert len(groups[tmp_path]) == 2 + assert len(groups[tmp_path / "subdir"]) == 1 + + def test_build_directory_tree(self, tmp_path: pathlib.Path): + """Test building directory tree.""" + # Create nested directories + (tmp_path / "a" / "b" / "c").mkdir(parents=True) + + # Create groups + groups = { + tmp_path: [], + tmp_path / "a": [], + tmp_path / "a" / "b": [], + tmp_path / "a" / "b" / "c": [], + } + + # Build tree + tree = DatasetDiscoveryPreparator._build_directory_tree(groups, tmp_path) + + # Verify tree structure + assert tmp_path / "a" in tree[tmp_path] + assert tmp_path / "a" / "b" in tree[tmp_path / "a"] + assert tmp_path / "a" / "b" / "c" in tree[tmp_path / "a" / "b"] + + def test_create_memmap_config(self, tmp_path: pathlib.Path): + """Test creating memmap config for dataset file.""" + dataset_file = tmp_path / "dataset.fast_llm_dataset" + dataset_file.touch() + + config = DatasetDiscoveryConfig( + directory=tmp_path, + output=tmp_path / "output.yaml", + ) + preparator = DatasetDiscoveryPreparator(config) + + # Create config + memmap_config = preparator._create_memmap_config_for_dataset(dataset_file) + + # Verify config structure + assert memmap_config["type"] == "memmap" + assert memmap_config["path"] == str(dataset_file) + + def test_get_directory_name(self, tmp_path: pathlib.Path): + """Test directory naming.""" + root = tmp_path + subdir = tmp_path / "data" / "train" + + # Test root directory + name = DatasetDiscoveryPreparator._get_directory_name(root, root) + assert name == root.name + + # Test subdirectory + name = DatasetDiscoveryPreparator._get_directory_name(subdir, root) + assert name == "data_train" + + # Test with suffix + name = DatasetDiscoveryPreparator._get_directory_name(subdir, root, "_local") + assert name == "data_train_local" + + @pytest.mark.slow + def test_dataset_discovery_e2e_single_dataset(self, tmp_path: pathlib.Path): + """Test end-to-end discovery with a single dataset.""" + import shutil + + import yaml + + from tests.utils.dataset import get_common_test_dataset + + # Get a prepared test dataset + dataset_path, _, _, _ = get_common_test_dataset() + + # Copy the .fast_llm_dataset file to temp directory + dataset_files = list(dataset_path.glob("*.fast_llm_dataset")) + assert len(dataset_files) > 0, "No dataset files found in test dataset" + + test_dataset = dataset_files[0] + (tmp_path / "datasets").mkdir() + shutil.copy(test_dataset, tmp_path / "datasets" / "dataset.fast_llm_dataset") + + # Run dataset discovery + output_path = tmp_path / "discovered_config.yaml" + config = DatasetDiscoveryConfig( + directory=tmp_path / "datasets", + output=output_path, + ) + config.run() + + # Verify output file was created + assert output_path.exists() + + # Load and verify the generated config + with open(output_path) as f: + content = f.read() + # Check header comments + assert "# This file was generated with fast_llm.data.preparator.dataset_discovery" in content + assert "weights are token-counts in billions" in content + assert f"# directory: {tmp_path / 'datasets'}" in content + + # Parse YAML + f.seek(0) + generated_config = yaml.safe_load(f) + + # Single dataset should be returned directly (not blended) + assert generated_config["type"] == "memmap" + assert "dataset.fast_llm_dataset" in generated_config["path"] + + @pytest.mark.slow + def test_dataset_discovery_e2e_multiple_datasets(self, tmp_path: pathlib.Path): + """Test end-to-end discovery with multiple datasets in flat structure.""" + import shutil + + import yaml + + from tests.utils.dataset import get_alt_test_dataset, get_common_test_dataset + + # Get two different test datasets + dataset1_path, _, _, _ = get_common_test_dataset() + dataset2_path, _, _, _ = get_alt_test_dataset() + + # Copy dataset files to temp directory + (tmp_path / "datasets").mkdir() + dataset1_file = list(dataset1_path.glob("*.fast_llm_dataset"))[0] + dataset2_file = list(dataset2_path.glob("*.fast_llm_dataset"))[0] + + shutil.copy(dataset1_file, tmp_path / "datasets" / "dataset1.fast_llm_dataset") + shutil.copy(dataset2_file, tmp_path / "datasets" / "dataset2.fast_llm_dataset") + + # Run dataset discovery + output_path = tmp_path / "discovered_config.yaml" + config = DatasetDiscoveryConfig( + directory=tmp_path / "datasets", + output=output_path, + ) + config.run() + + # Verify output file was created + assert output_path.exists() + + # Load and verify the generated config + with open(output_path) as f: + generated_config = yaml.safe_load(f) + + # Multiple datasets should create a blended config + assert generated_config["type"] == "blended" + assert len(generated_config["datasets"]) == 2 + assert len(generated_config["weights"]) == 2 + + # Verify all weights are positive (in billions) + assert all(w > 0 for w in generated_config["weights"]) + + # Verify datasets are memmap configs + for dataset_config in generated_config["datasets"]: + assert dataset_config["type"] == "memmap" + assert "dataset" in dataset_config["path"] + + @pytest.mark.slow + def test_dataset_discovery_e2e_hierarchical_structure(self, tmp_path: pathlib.Path): + """Test end-to-end discovery with hierarchical directory structure.""" + import shutil + + import yaml + + from tests.utils.dataset import get_alt_test_dataset, get_common_test_dataset + + # Get test datasets + dataset1_path, _, _, _ = get_common_test_dataset() + dataset2_path, _, _, _ = get_alt_test_dataset() + + # Create hierarchical structure + (tmp_path / "root").mkdir() + (tmp_path / "root" / "group1").mkdir() + (tmp_path / "root" / "group2").mkdir() + + dataset1_file = list(dataset1_path.glob("*.fast_llm_dataset"))[0] + dataset2_file = list(dataset2_path.glob("*.fast_llm_dataset"))[0] + + # Place datasets in hierarchy + shutil.copy(dataset1_file, tmp_path / "root" / "dataset_a.fast_llm_dataset") + shutil.copy(dataset2_file, tmp_path / "root" / "dataset_b.fast_llm_dataset") + shutil.copy(dataset1_file, tmp_path / "root" / "group1" / "dataset_c.fast_llm_dataset") + shutil.copy(dataset2_file, tmp_path / "root" / "group2" / "dataset_d.fast_llm_dataset") + + # Run dataset discovery + output_path = tmp_path / "discovered_config.yaml" + config = DatasetDiscoveryConfig( + directory=tmp_path / "root", + output=output_path, + ) + config.run() + + # Load and verify the generated config + with open(output_path) as f: + generated_config = yaml.safe_load(f) + + # Should create hierarchical blended config + assert generated_config["type"] == "blended" + + # Root should have 3 items: local group + 2 subdirs + assert len(generated_config["datasets"]) == 3 + + # First item should be local datasets grouped with "_local" suffix + local_group = generated_config["datasets"][0] + assert local_group["type"] == "blended" + assert "_local" in local_group["name"] + assert len(local_group["datasets"]) == 2 + + # Next two should be subdirectory datasets (single dataset each, so memmap type) + # Check that one is from group1 and one from group2 + subdir_paths = [generated_config["datasets"][1]["path"], generated_config["datasets"][2]["path"]] + assert any("group1" in path for path in subdir_paths) + assert any("group2" in path for path in subdir_paths) + + @pytest.mark.slow + def test_dataset_discovery_e2e_with_ignore_paths(self, tmp_path: pathlib.Path): + """Test end-to-end discovery with ignore_paths.""" + import shutil + + import yaml + + from tests.utils.dataset import get_common_test_dataset + + # Get test dataset + dataset_path, _, _, _ = get_common_test_dataset() + dataset_file = list(dataset_path.glob("*.fast_llm_dataset"))[0] + + # Create directory structure + (tmp_path / "datasets" / "keep").mkdir(parents=True) + (tmp_path / "datasets" / "ignore").mkdir(parents=True) + + # Place datasets + shutil.copy(dataset_file, tmp_path / "datasets" / "keep" / "dataset1.fast_llm_dataset") + shutil.copy(dataset_file, tmp_path / "datasets" / "ignore" / "dataset2.fast_llm_dataset") + + # Run dataset discovery with ignore_paths + output_path = tmp_path / "discovered_config.yaml" + config = DatasetDiscoveryConfig( + directory=tmp_path / "datasets", + output=output_path, + ignore_paths=[pathlib.Path("ignore")], + ) + config.run() + + # Load and verify the generated config + with open(output_path) as f: + content = f.read() + # Check ignore_paths in header + assert "ignore_paths:" in content + assert "ignore" in content + + # Parse YAML + f.seek(0) + generated_config = yaml.safe_load(f) + + # Should only include the dataset from "keep" directory + # Single dataset, so should be memmap (not blended) + assert generated_config["type"] == "memmap" + assert "keep" in generated_config["path"] + assert "ignore" not in generated_config["path"] + + @pytest.mark.slow + def test_dataset_discovery_e2e_empty_directory(self, tmp_path: pathlib.Path): + """Test that discovery fails gracefully on empty directory.""" + # Create empty directory + (tmp_path / "empty").mkdir() + + # Run dataset discovery - should raise ValueError + output_path = tmp_path / "output.yaml" + config = DatasetDiscoveryConfig( + directory=tmp_path / "empty", + output=output_path, + ) + + with pytest.raises(ValueError, match="No .fast_llm_dataset files found"): + config.run()