diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 01cb3abe88..3ea345c09f 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -230,6 +230,12 @@ def get_dataset_samples( or a path to a ``.jsonl`` file. For local directory paths, the predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key). + For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json`` + builder and routed through the same auto-preprocess path as unregistered HF + datasets so chat / prompt / text columns are handled consistently with live + HF datasets. If that path fails (e.g. PyArrow schema unification across + heterogeneous rows), it falls back to a line-by-line reader that extracts + the legacy ``text`` field for backward compatibility. num_samples: Number of samples to load from the dataset. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). For unregistered datasets with a @@ -244,18 +250,23 @@ def get_dataset_samples( Returns: Samples: The list of samples. """ - # Local JSONL file path support (each line is a JSON object with a `text` field). - if dataset_name.endswith(".jsonl"): - return get_jsonl_text_samples(dataset_name, num_samples, key="text") - from datasets import load_dataset + # Local JSONL: load via HF's ``json`` builder and route through the same + # auto-preprocess path as unregistered HF datasets so chat / prompt / text + # columns are handled consistently with a downloaded HF dataset. Never + # matches ``SUPPORTED_DATASET_CONFIG``. + is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name) + local_dataset_path = None if os.path.exists(dataset_name): # Local path local_dataset_path = dataset_name - dataset_name = os.path.basename(os.path.normpath(local_dataset_path)) + if not is_jsonl: + # Directory paths may match a registered key via their basename + # (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail). + dataset_name = os.path.basename(os.path.normpath(local_dataset_path)) - is_registered = dataset_name in SUPPORTED_DATASET_CONFIG + is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG if is_registered: dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] @@ -291,7 +302,14 @@ def _preprocess(sample: dict) -> str: f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. " "Auto-detecting format from column names." ) - config = {"path": local_dataset_path or dataset_name} + if is_jsonl: + config = {"path": "json", "data_files": local_dataset_path} + else: + config = {"path": local_dataset_path or dataset_name} + # HF's file-based builders (incl. ``json``) label a string/list ``data_files`` + # as the ``train`` split unconditionally — the filename on disk is ignored. + # Named splits require a dict ``data_files={"train": ..., "test": ...}``, + # which we don't expose here. splits = _normalize_splits(split) if split is not None else ["train"] def _preprocess(sample: dict) -> str: @@ -299,21 +317,42 @@ def _preprocess(sample: dict) -> str: # load_dataset does not support a list of splits while streaming, so load each separately. print(f"Loading dataset with {config=} and {splits=}") - dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits] - - num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits) - num_per_split[-1] += num_samples - sum(num_per_split) - - samples: list[str] = [] - for dataset, n in zip(dataset_splits, num_per_split): - for i, sample in enumerate(dataset): - if i >= n: - break - text = _preprocess(sample) - if text: - samples.append(text) - - return samples + try: + dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits] + + num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits) + num_per_split[-1] += num_samples - sum(num_per_split) + + samples: list[str] = [] + for dataset, n in zip(dataset_splits, num_per_split): + for i, sample in enumerate(dataset): + if i >= n: + break + text = _preprocess(sample) + if text: + samples.append(text) + + return samples + except Exception as e: + # Backward-compat fallback: legacy callers passed JSONL files whose only usable + # field is ``text``. If the HF ``json`` builder or auto-detect can't handle the + # file (schema inference error, unrecognized columns, etc.), fall back to a + # line-by-line reader that pulls the ``text`` field directly. + if is_jsonl: + assert local_dataset_path is not None # is_jsonl implies the path exists + try: + fallback_samples = get_jsonl_text_samples( + local_dataset_path, num_samples, key="text" + ) + except Exception: + # Fallback can't help either — surface the original HF error. + raise e from None + warn( + f"Failed to load {local_dataset_path} via the HF 'json' builder " + f"({type(e).__name__}: {e}); fell back to legacy text-field reader." + ) + return fallback_samples + raise class _CustomDataset(torch.utils.data.Dataset): @@ -344,8 +383,10 @@ def get_dataset_dataloader( """Get a dataloader with the dataset name and tokenizer of the target model. Args: - dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file. - If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field. + dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list + mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the + resulting samples are concatenated before tokenization. ``num_samples`` may be + an ``int`` (applied to a single source) or a list aligned with ``dataset_name``. tokenizer: Instance of HuggingFace tokenizer. batch_size: Batch size of the returned dataloader. num_samples: Number of samples from the dataset. diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 9a89d53672..fbaa29e290 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -18,7 +18,11 @@ import pytest import torch -from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples +from modelopt.torch.utils.dataset_utils import ( + _process_batch, + get_dataset_dataloader, + get_dataset_samples, +) def setup_test_data(): @@ -167,3 +171,336 @@ def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_lo assert isinstance(samples, list) assert len(samples) == 5 assert all(isinstance(s, str) and len(s) > 0 for s in samples) + + +# --------------------------------------------------------------------------- +# Local JSONL loading — must flow through the same auto-preprocess path as a +# downloaded HF dataset, so chat / prompt / text columns are all handled. +# --------------------------------------------------------------------------- + + +def _write_jsonl(path, rows): + """Write a list of dicts to *path* as JSONL. Returns the path as ``str``.""" + import json + + with open(path, "w", encoding="utf-8") as f: + f.writelines(json.dumps(row) + "\n" for row in rows) + return str(path) + + +@pytest.fixture +def chat_tokenizer(): + """Mock tokenizer whose ``apply_chat_template`` joins messages role:content.""" + tok = Mock() + tok.apply_chat_template = Mock( + side_effect=lambda msgs, tokenize=False, **kw: " | ".join( + f"{m['role']}:{m['content']}" for m in msgs + ) + ) + return tok + + +class TestLocalJsonlLoading: + """Local ``.jsonl`` paths route through HF's ``json`` builder + auto-preprocess.""" + + def test_text_column(self, tmp_path): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "plain.jsonl", + [{"text": f"plain {i}"} for i in range(3)], + ) + samples = get_dataset_samples(path, num_samples=3) + assert samples == ["plain 0", "plain 1", "plain 2"] + + def test_messages_column_uses_chat_template(self, tmp_path, chat_tokenizer): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "chat.jsonl", + [ + { + "messages": [ + {"role": "user", "content": f"hello {i}"}, + {"role": "assistant", "content": f"hi {i}"}, + ] + } + for i in range(3) + ], + ) + samples = get_dataset_samples(path, num_samples=3, tokenizer=chat_tokenizer) + assert len(samples) == 3 + assert samples[0] == "user:hello 0 | assistant:hi 0" + # apply_chat_template must have been invoked once per sample + assert chat_tokenizer.apply_chat_template.call_count == 3 + + def test_conversations_column_uses_chat_template(self, tmp_path, chat_tokenizer): + """Auto-detect also recognizes ``conversations`` (Magpie-style).""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "convs.jsonl", + [ + { + "conversations": [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "a"}, + ] + } + ], + ) + samples = get_dataset_samples(path, num_samples=1, tokenizer=chat_tokenizer) + assert samples == ["user:q | assistant:a"] + + def test_prompt_completion_concatenated(self, tmp_path): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "prompt.jsonl", + [{"prompt": "Q?", "completion": "A."}], + ) + samples = get_dataset_samples(path, num_samples=1) + assert samples == ["Q?\nA."] + + def test_input_output_concatenated(self, tmp_path): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "io.jsonl", + [{"input": "in", "output": "out"}], + ) + samples = get_dataset_samples(path, num_samples=1) + assert samples == ["in\nout"] + + def test_num_samples_honored(self, tmp_path): + """Loads only the requested number of rows even when the file is larger.""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "many.jsonl", + [{"text": f"row {i}"} for i in range(100)], + ) + samples = get_dataset_samples(path, num_samples=5) + assert len(samples) == 5 + assert samples == [f"row {i}" for i in range(5)] + + def test_tools_forwarded_to_chat_template(self, tmp_path, chat_tokenizer): + """If a row carries a ``tools`` field, it's passed through to apply_chat_template.""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "tools.jsonl", + [ + { + "messages": [{"role": "user", "content": "x"}], + "tools": [{"name": "calc"}], + } + ], + ) + get_dataset_samples(path, num_samples=1, tokenizer=chat_tokenizer) + _, kwargs = chat_tokenizer.apply_chat_template.call_args + assert kwargs.get("tools") == [{"name": "calc"}] + + def test_unrecognized_columns_raise(self, tmp_path): + """Auto-detect raises ValueError when no recognized column is present. + + The HF builder loads the rows fine; auto-detect rejects them. There's no + ``text`` field to fall back to, so the error propagates instead of being + masked by the legacy fallback. + """ + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "bad.jsonl", + [{"unrelated_field": "value"}], + ) + with pytest.raises(ValueError, match="Cannot auto-detect format"): + get_dataset_samples(path, num_samples=1) + + def test_legacy_text_fallback_on_hf_builder_failure(self, tmp_path): + """If the HF json builder raises, fall back to the legacy text-field reader.""" + pytest.importorskip("datasets") + # Mixed-type ``meta`` field across rows — int vs string — trips PyArrow + # schema unification in the HF json builder. The rows still carry a + # ``text`` field, so the legacy reader can recover the samples. + rows = [ + {"text": "row a", "meta": 1}, + {"text": "row b", "meta": "two"}, + {"text": "row c", "meta": 3}, + ] + path = _write_jsonl(tmp_path / "mixed.jsonl", rows) + samples = get_dataset_samples(path, num_samples=3) + assert samples == ["row a", "row b", "row c"] + + +# --------------------------------------------------------------------------- +# get_dataset_dataloader — blending across multiple sources +# --------------------------------------------------------------------------- + + +class _FakeTokenizer: + """Minimal callable tokenizer that mimics the HF tokenizer surface used by the dataloader. + + Tokenizes by character ordinal and left-pads to the longest sample (capped at max_length). + Avoids a hard dependency on ``transformers`` in the test environment. + """ + + padding_side = "left" + pad_token_id = 0 + + def __call__(self, texts, return_tensors=None, padding=True, truncation=True, max_length=16): + ids = [[ord(c) % 100 + 1 for c in t][:max_length] for t in texts] + n = max(len(x) for x in ids) + input_ids = [[self.pad_token_id] * (n - len(x)) + x for x in ids] + attention = [[0] * (n - len(x)) + [1] * len(x) for x in ids] + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention, dtype=torch.long), + } + + +@pytest.fixture +def pad_tokenizer(): + return _FakeTokenizer() + + +class TestGetDatasetDataloaderBlending: + """``get_dataset_dataloader`` accepts a list of sources and concatenates them.""" + + def test_single_jsonl(self, tmp_path, pad_tokenizer): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "single.jsonl", + [{"text": f"row {i}"} for i in range(4)], + ) + loader = get_dataset_dataloader( + dataset_name=path, + tokenizer=pad_tokenizer, + batch_size=2, + num_samples=4, + max_sample_length=16, + ) + batches = list(loader) + assert len(batches) == 2 + assert batches[0]["input_ids"].shape[0] == 2 + assert "attention_mask" in batches[0] + + def test_list_of_jsonl_blends(self, tmp_path, pad_tokenizer): + """Two local JSONL files concatenated into a single dataloader.""" + pytest.importorskip("datasets") + a = _write_jsonl(tmp_path / "a.jsonl", [{"text": f"a{i}"} for i in range(3)]) + b = _write_jsonl(tmp_path / "b.jsonl", [{"text": f"b{i}"} for i in range(2)]) + + loader = get_dataset_dataloader( + dataset_name=[a, b], + tokenizer=pad_tokenizer, + batch_size=5, + num_samples=[3, 2], + max_sample_length=16, + ) + batches = list(loader) + assert len(batches) == 1 + assert batches[0]["input_ids"].shape[0] == 5 + + def test_mixed_formats_blended(self, tmp_path, pad_tokenizer): + """Mixing a text-column JSONL with a prompt/completion JSONL — both should flow.""" + pytest.importorskip("datasets") + plain = _write_jsonl(tmp_path / "plain.jsonl", [{"text": "hello"}]) + pc = _write_jsonl(tmp_path / "pc.jsonl", [{"prompt": "Q?", "completion": "A."}]) + + loader = get_dataset_dataloader( + dataset_name=[plain, pc], + tokenizer=pad_tokenizer, + batch_size=2, + num_samples=[1, 1], + max_sample_length=16, + ) + batches = list(loader) + assert len(batches) == 1 + assert batches[0]["input_ids"].shape[0] == 2 + + def test_length_mismatch_raises(self, tmp_path, pad_tokenizer): + """``dataset_name`` and ``num_samples`` lists must align.""" + pytest.importorskip("datasets") + a = _write_jsonl(tmp_path / "a.jsonl", [{"text": "x"}]) + b = _write_jsonl(tmp_path / "b.jsonl", [{"text": "y"}]) + with pytest.raises(AssertionError, match="same length"): + get_dataset_dataloader( + dataset_name=[a, b], + tokenizer=pad_tokenizer, + num_samples=[1], + max_sample_length=16, + ) + + +# --------------------------------------------------------------------------- +# Live HF dataset round-trips. ``hf-internal-testing/dataset_with_data_files`` +# is a 10-row x {train,test} fixture maintained by HF for their own CI — tiny +# enough to download in a unit test and stable across releases. +# --------------------------------------------------------------------------- + +_HF_TINY = "hf-internal-testing/dataset_with_data_files" # train, test splits, ``text`` col + + +def _hf_dump_to_jsonl(name: str, split: str, path) -> str: + from datasets import load_dataset + + ds = load_dataset(name, split=split) + ds.to_json(str(path), lines=True) + return str(path) + + +class TestHfTinyDataset: + """End-to-end coverage with a real (tiny) HF dataset.""" + + def test_load_single_split_directly(self): + pytest.importorskip("datasets") + samples = get_dataset_samples(_HF_TINY, num_samples=4, split="train") + assert len(samples) == 4 + assert all(isinstance(s, str) and s for s in samples) + + def test_load_multiple_splits_directly(self): + """``split=["train", "test"]`` divides ``num_samples`` across both splits.""" + pytest.importorskip("datasets") + samples = get_dataset_samples(_HF_TINY, num_samples=6, split=["train", "test"]) + assert len(samples) == 6 + # Default per-split is num_samples // n + remainder; for 6/2 → 3 from each. + # We can't assert exact origin without re-reading, but both splits should + # contribute, which we'll confirm by comparing against direct loads below. + train_only = set(get_dataset_samples(_HF_TINY, num_samples=10, split="train")) + test_only = set(get_dataset_samples(_HF_TINY, num_samples=10, split="test")) + assert any(s in train_only for s in samples) + assert any(s in test_only for s in samples) + + def test_default_split_is_train(self): + pytest.importorskip("datasets") + default_samples = get_dataset_samples(_HF_TINY, num_samples=4) + train_samples = get_dataset_samples(_HF_TINY, num_samples=4, split="train") + assert default_samples == train_samples + + def test_download_to_jsonl_then_load(self, tmp_path): + """Dump the HF dataset to JSONL, then reload it via the local-jsonl path.""" + pytest.importorskip("datasets") + jsonl_path = _hf_dump_to_jsonl(_HF_TINY, "train", tmp_path / "train.jsonl") + from_jsonl = get_dataset_samples(jsonl_path, num_samples=10) + from_hf = get_dataset_samples(_HF_TINY, num_samples=10, split="train") + assert from_jsonl == from_hf + + def test_dataloader_blending_two_hf_datasets(self, pad_tokenizer): + """Two HF datasets concatenated via ``get_dataset_dataloader``.""" + pytest.importorskip("datasets") + loader = get_dataset_dataloader( + dataset_name=[_HF_TINY, "hf-internal-testing/multi_dir_dataset"], + tokenizer=pad_tokenizer, + batch_size=4, + num_samples=[3, 1], + max_sample_length=16, + ) + batches = list(loader) + assert sum(b["input_ids"].shape[0] for b in batches) == 4 + + def test_dataloader_mixing_hf_and_local_jsonl(self, tmp_path, pad_tokenizer): + """Live HF dataset blended with a local synthetic JSONL file.""" + pytest.importorskip("datasets") + local = _write_jsonl(tmp_path / "local.jsonl", [{"text": f"local {i}"} for i in range(2)]) + loader = get_dataset_dataloader( + dataset_name=[_HF_TINY, local], + tokenizer=pad_tokenizer, + batch_size=5, + num_samples=[3, 2], + max_sample_length=16, + ) + batches = list(loader) + assert sum(b["input_ids"].shape[0] for b in batches) == 5