Commit 534e634

Dedup exact lines training tokenizer dataset (#409)
1 parent 4560406 commit 534e634

10 files changed, +651 -0 lines changed
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
"""Taken from Teven and Leandro"""
import gzip
import os
import shutil
import time
import logging
import argparse

from datasets import load_from_disk
from datasets.utils.logging import set_verbosity_info


set_verbosity_info()
logger = logging.getLogger(__name__)


def get_args():
    parser = argparse.ArgumentParser(description="Load seed and upload to hub")
    parser.add_argument(
        "--save-dir", required=True, type=str, help="Where to save the datasets."
    )
    parser.add_argument(
        "--dataset_dir",
        help="path to where the arrow dataset is located",
        required=True,
        type=str,
    )
    parser.add_argument(
        "--batch-size",
        help="Batch size used for the mapping and saving of the dataset",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--num-proc",
        help="Number of processors used for the mapping and saving of the dataset",
        required=True,
        type=int,
    )
    args = parser.parse_args()
    return args


def get_hash(example):
    """Get hash of content field."""
    return {"hash": hash(example["text"].replace(" ", ""))}


def check_uniques(example, uniques):
    """Check if current hash is still in set of unique hashes and remove if true."""
    if example["hash"] in uniques:
        uniques.remove(example["hash"])
        return True
    else:
        return False


def preprocess(example):
    """Chain all preprocessing steps into one function to not fill cache."""
    results = dict()
    results.update(get_hash(example))
    return results


def filter(example, uniques, args):
    """Filter dataset with heuristics."""
    if not check_uniques(example, uniques):
        return False
    else:
        return True


def compress_file(file_path):
    """Compress a file with g-zip."""
    with open(file_path, "rb") as f_in:
        with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.unlink(file_path)


def main():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    args = get_args()

    # Load dataset
    t_start = time.time()
    ds = load_from_disk(args.dataset_dir)
    logger.info(f"Time to load dataset: {time.time()-t_start:.2f}")

    # Run preprocessing
    t_start = time.time()
    ds = ds.map(preprocess, num_proc=args.num_proc)
    logger.info(f"Time to preprocess dataset: {time.time()-t_start:.2f}")

    # Deduplicate hashes
    uniques = set(ds.unique("hash"))
    frac = len(uniques) / len(ds)
    logger.info(f"Fraction of duplicates: {1-frac:.2%}")

    # Deduplicate data and apply heuristics
    t_start = time.time()
    ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
    logger.info(f"Time to filter dataset: {time.time()-t_start:.2f}")
    logger.info(f"Size of filtered dataset: {len(ds_filter)}")

    # Save data
    t_start = time.time()
    ds_filter.save_to_disk(args.save_dir)

    logger.info(f"Time to save dataset: {time.time()-t_start:.2f}")


if __name__ == "__main__":
    main()
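
The sketch below is not part of the commit; it only illustrates the document-level exact dedup pattern the script above uses (hash each text with spaces removed, then keep only the first occurrence of each hash). The toy dataset and the keep_first name are invented for illustration.

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a b c", "abc", "d e f"]})

# same idea as get_hash(): whitespace-insensitive exact hash of the text field
ds = ds.map(lambda ex: {"hash": hash(ex["text"].replace(" ", ""))})
uniques = set(ds.unique("hash"))

def keep_first(example):
    # same logic as check_uniques(): keep a hash only the first time it is seen
    if example["hash"] in uniques:
        uniques.remove(example["hash"])
        return True
    return False

deduped = ds.filter(keep_first)  # single process, so the `uniques` set is shared
print(deduped["text"])  # ['a b c', 'd e f'] -- "abc" collides with "a b c"
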
Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
import json
import shutil
from collections import defaultdict
import os
import argparse
import logging

import datasets
from functools import partial
import pandas as pd
from datasets import Features, load_dataset, load_from_disk
from tqdm import tqdm
from datasets.utils.logging import set_verbosity_info
from numpy.random import SeedSequence, default_rng

"""
Cleaning text:
    - run exact deduplication
"""

set_verbosity_info()
logger = logging.getLogger(__name__)

###
# seed processing and upload functions
###


META_COLUMNS = ["meta"]

# filter text to remove certain lines (e.g. menu items, copyright notice)
def filter_lines(article, skip_set, used_lines):
    # TODO discuss the strip
    lines = [line.strip() for line in article.split("\n")]
    keep = []
    skip = []
    for line in lines:
        if line in skip_set and line in used_lines:
            skip += [line]
        elif line in skip_set:
            keep += [line]
            used_lines.add(line)
        else:
            keep += [line]
    return "\n".join(keep).strip(), "\n".join(skip).strip()


def filter_lines_by_batch(texts, skip_set, used_lines, preserve_code, metadata=None):
    if preserve_code:
        filtered_lines = [
            filter_lines(article, skip_set, used_lines)
            if "lm_code" in eval(metadata_item)["source_dataset"]
            else (article, "")
            for article, metadata_item in zip(texts, metadata)
        ]
    else:
        filtered_lines = [
            filter_lines(article, skip_set, used_lines) for article in texts
        ]
    return tuple(zip(*filtered_lines))


# do both together and return an entry
def process_batch(batch, skip_set, used_lines, args):
    if not args.with_meta_col:
        texts, _ = filter_lines_by_batch(
            batch["text"], skip_set, used_lines, preserve_code=False
        )
        return {
            "text": texts,
        }
    else:
        texts, _ = filter_lines_by_batch(
            batch["text"],
            skip_set,
            used_lines,
            preserve_code=args.preserve_code,
            metadata=batch["meta"],
        )
        return {
            "meta": batch["meta"],
            "text": texts,
        }


# looks at up to the first 10K pages for a seed and
# records lines that appear in at least 1% of the unique pages
def get_lines_to_skip(dset, n_records, pourcentage_threshold, min_repetition_threshold):
    line_counts = defaultdict(lambda: 0)
    seen_pages = defaultdict(lambda: 0)

    seed = SeedSequence(42)
    rng = default_rng(seed)
    num_elements = min(len(dset), n_records)
    indices = rng.choice(len(dset), size=num_elements, replace=False, shuffle=False)

    dset_sample = dset.select(indices)
    for page in tqdm(dset_sample):
        article = page["text"]

        seen_pages[article] += 1
        # We count the number of times we see identical lines in different documents.
        all_lines = {line.strip() for line in article.split("\n")}
        for line in all_lines:
            line_counts[line] += 1

    # TODO understand this logic, why it's not len(line_counts)
    if pourcentage_threshold is not None:
        thres_skip = max(
            min_repetition_threshold, len(seen_pages) * pourcentage_threshold
        )
    else:
        thres_skip = min_repetition_threshold
    skip_set = {line for line, ct in line_counts.items() if ct > thres_skip}
    return skip_set, seen_pages


def clean_examples(examples, skip_lines_set, used_lines, args):
    if args.with_meta_col:
        results = {"text": [], "meta": []}
    else:
        results = {"text": []}
    # Collapses meta and cleans text
    preprocessed_batch = process_batch(examples, skip_lines_set, used_lines, args)
    assert set(results.keys()) == set(preprocessed_batch.keys())

    for idx, cleaned_article in enumerate(preprocessed_batch["text"]):
        if len(cleaned_article) <= args.min_chars:
            continue
        for key in results.keys():
            results[key].append(preprocessed_batch[key][idx])

    return results


# create a private repository and push processed seed in jsonl format
TEXT_COLUMN = "text"


def filter_and_save(dset, skip_lines_set, seen_pages, args):
    repo_name = args.save_dir
    # TODO build a caching mechanism
    repo_name_tmp = f"{repo_name}.tmp"
    if not os.path.isdir(repo_name_tmp):
        os.makedirs(repo_name_tmp)

    # process
    used_lines = set()
    dset = dset.map(
        partial(
            clean_examples,
            skip_lines_set=skip_lines_set,
            used_lines=used_lines,
            args=args,
        ),
        batched=True,
        # num_proc=args.num_proc,  # single process for used_lines
        batch_size=args.batch_size,
        remove_columns=dset.column_names,
    )
    logger.info(f"Finished cleaning")

    # write to folder
    dset.save_to_disk(repo_name_tmp)

    logger.info(f"Ended successfully, saved at {repo_name_tmp}")

    # Saving skipped lines that are considered repetitive
    with open(os.path.join(repo_name_tmp, "skipped_lines.json"), "w") as fi:
        json.dump(list(skip_lines_set), fi, indent=2)

    # Saving num of duplicated documents
    with open(os.path.join(repo_name_tmp, "duplicate_documents.json"), "w") as fi:
        json.dump([num for num in list(seen_pages.values()) if num > 1], fi, indent=2)

    # Move so that the state becomes completed
    shutil.move(repo_name_tmp, repo_name)


def text_is_not_none(batch):
    return [text is not None for text in batch["text"]]


###
# combine everything
###
def main():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-dir", required=True, type=str, help="Where to save the datasets."
    )
    parser.add_argument(
        "--dataset_dir",
        help="path to where the arrow dataset is located",
        required=True,
        type=str,
    )
    parser.add_argument(
        "--batch-size",
        help="Batch size used for mapping the dataset",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--num-proc",
        help="Number of processors used for the mapping of the dataset",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--min-chars",
        help="Minimum number of chars in a line",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--n-records",
        help="Number of records used to compute the repetitions",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--pourcentage-threshold",
        help="Threshold used for filtering repetitions",
        default=None,
        type=float,
    )
    parser.add_argument(
        "--min-repetition-threshold",
        help="Minimum threshold used for filtering repetitions. Used when the number of available records is not enough",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--with-meta-col",
        help="If the initial dataset has a meta column",
        action="store_true",
    )
    parser.add_argument(
        "--preserve_code",
        help="Exclude code datasets from the line dedup",
        action="store_true",
    )
    args = parser.parse_args()
    # Load dataset (data first needs to be git pulled, see above)

    dset = load_from_disk(args.dataset_dir)

    # pre-remove unnecessary columns, hopefully that saves quite a bit of memory usage
    columns_to_keep = [TEXT_COLUMN] + META_COLUMNS
    dset = dset.remove_columns(list(set(dset.column_names) - set(columns_to_keep)))

    # Filter None text columns
    number_of_samples_before = len(dset)
    dset = dset.filter(text_is_not_none, batched=True, num_proc=args.num_proc)
    number_of_samples_after_filtering_none = len(dset)
    logger.info(
        f"Filtered out {number_of_samples_before - number_of_samples_after_filtering_none} / {number_of_samples_before}"
    )

    skip_lines_set, seen_pages = get_lines_to_skip(
        dset,
        n_records=args.n_records,
        pourcentage_threshold=args.pourcentage_threshold,
        min_repetition_threshold=args.min_repetition_threshold,
    )

    filter_and_save(
        dset, skip_lines_set=skip_lines_set, seen_pages=seen_pages, args=args
    )
    logger.info("Finished")


if __name__ == "__main__":
    main()
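
The second script works at line level rather than document level: get_lines_to_skip samples up to --n-records documents and flags lines that recur in more than a threshold number of them (set by --min-repetition-threshold, or a fraction of the sample via --pourcentage-threshold), and filter_lines then drops each flagged line everywhere except its first occurrence. A small sketch of that behaviour, not part of the commit, assuming filter_lines from the file above is in scope; the skip line and articles are invented:

skip_set = {"Subscribe to our newsletter"}
used_lines = set()

articles = [
    "Real content A\nSubscribe to our newsletter",
    "Subscribe to our newsletter\nReal content B",
]
for article in articles:
    kept, skipped = filter_lines(article, skip_set, used_lines)
    print(repr(kept), "| skipped:", repr(skipped))

# The first call keeps the boilerplate line (first sighting) and records it in
# used_lines; the second call drops it, so exactly one copy survives in the corpus.
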
