89 changes: 65 additions & 24 deletions modelopt/torch/utils/dataset_utils.py
@@ -230,6 +230,12 @@ def get_dataset_samples(
or a path to a ``.jsonl`` file. For local directory paths, the
predefined config from ``SUPPORTED_DATASET_CONFIG`` is used when the base folder name
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches the ``cnn_dailymail`` key).
For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json``
builder and routed through the same auto-preprocess path as unregistered HF
datasets so chat / prompt / text columns are handled consistently with live
HF datasets. If that path fails (e.g. on a PyArrow schema-unification error
across heterogeneous rows), loading falls back to a line-by-line reader that
extracts the legacy ``text`` field for backward compatibility.
num_samples: Number of samples to load from the dataset.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset). For unregistered datasets with a
@@ -244,18 +250,23 @@
Returns:
Samples: The list of samples.
"""
# Local JSONL file path support (each line is a JSON object with a `text` field).
if dataset_name.endswith(".jsonl"):
return get_jsonl_text_samples(dataset_name, num_samples, key="text")

from datasets import load_dataset

# Local JSONL: load via HF's ``json`` builder and route through the same
# auto-preprocess path as unregistered HF datasets so chat / prompt / text
# columns are handled consistently with a downloaded HF dataset. JSONL
# paths never match ``SUPPORTED_DATASET_CONFIG``.
is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name)

local_dataset_path = None
if os.path.exists(dataset_name): # Local path
local_dataset_path = dataset_name
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))
if not is_jsonl:
# Directory paths may match a registered key via their basename
# (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail).
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))

is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG

if is_registered:
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
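
A minimal sketch of how this routing resolves for different inputs (paths are hypothetical; ``tokenizer`` is assumed to be an already-loaded HF tokenizer accepted as a keyword argument):

# Hypothetical calls, not part of the diff; keyword names follow the docstring above.
get_dataset_samples(dataset_name="cnn_dailymail", num_samples=128, tokenizer=tokenizer)  # registered key
get_dataset_samples(dataset_name="/hf-local/abisee/cnn_dailymail", num_samples=128, tokenizer=tokenizer)  # basename matches a registered key
get_dataset_samples(dataset_name="/data/calib.jsonl", num_samples=128, tokenizer=tokenizer)  # is_jsonl=True -> HF 'json' builder
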
@@ -291,29 +302,57 @@ def _preprocess(sample: dict) -> str:
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
"Auto-detecting format from column names."
)
config = {"path": local_dataset_path or dataset_name}
if is_jsonl:
config = {"path": "json", "data_files": local_dataset_path}
else:
config = {"path": local_dataset_path or dataset_name}
# HF's file-based builders (incl. ``json``) label a string/list ``data_files``
# as the ``train`` split unconditionally — the filename on disk is ignored.
# Named splits require a dict ``data_files={"train": ..., "test": ...}``,
# which we don't expose here.
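# e.g. load_dataset("json", data_files="calib.jsonl") yields DatasetDict({"train": ...}).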
splits = _normalize_splits(split) if split is not None else ["train"]

def _preprocess(sample: dict) -> str:
return _auto_preprocess_sample(sample, dataset_name, tokenizer)

# load_dataset does not support a list of splits while streaming, so load each separately.
print(f"Loading dataset with {config=} and {splits=}")
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)

samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)

return samples
try:
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)
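# e.g. num_samples=10 over 3 splits: [3, 3, 3] -> [3, 3, 4] (remainder goes to the last split).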

samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)

return samples
except Exception as e:
# Backward-compat fallback: legacy callers passed JSONL files whose only usable
# field is ``text``. If the HF ``json`` builder or auto-detect can't handle the
# file (schema inference error, unrecognized columns, etc.), fall back to a
# line-by-line reader that pulls the ``text`` field directly.
if is_jsonl:
assert local_dataset_path is not None # is_jsonl implies the path exists
try:
fallback_samples = get_jsonl_text_samples(
local_dataset_path, num_samples, key="text"
)
except Exception:
# Fallback can't help either — surface the original HF error.
raise e from None
warn(
f"Failed to load {local_dataset_path} via the HF 'json' builder "
f"({type(e).__name__}: {e}); fell back to legacy text-field reader."
)
return fallback_samples
raise
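
For context, a self-contained sketch of the two JSONL loading paths exercised above (the file path is hypothetical; get_jsonl_text_samples is the legacy reader the except branch calls):

import json

from datasets import load_dataset

# Primary path: HF 'json' builder, streamed. A plain string ``data_files``
# always lands in the "train" split regardless of the file name on disk.
stream = load_dataset("json", data_files="/data/calib.jsonl", split="train", streaming=True)
rows = [row for _, row in zip(range(4), stream)]  # rows keep their original columns

# Legacy fallback shape: line-by-line JSON, keeping only the `text` field,
# mirroring get_jsonl_text_samples("/data/calib.jsonl", 4, key="text").
with open("/data/calib.jsonl") as f:
    legacy = [json.loads(line)["text"] for line in f][:4]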


class _CustomDataset(torch.utils.data.Dataset):
@@ -344,8 +383,10 @@ def get_dataset_dataloader(
"""Get a dataloader with the dataset name and tokenizer of the target model.

Args:
dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list
mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the
resulting samples are concatenated before tokenization. ``num_samples`` may be
an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
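
A usage sketch of the mixed-source form described above (model name, file path, and sample counts are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works here
loader = get_dataset_dataloader(
    dataset_name=["cnn_dailymail", "/data/extra_calib.jsonl"],  # HF key + local JSONL
    tokenizer=tokenizer,
    batch_size=4,
    num_samples=[96, 32],  # aligned with dataset_name; 128 samples total
)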