77one-shot calibration workflows.
88"""
99
10+ import math
1011import multiprocessing
1112import re
1213from typing import Any , Callable
1516from datasets import Dataset
1617from loguru import logger
1718from torch .utils .data import DataLoader , RandomSampler , SequentialSampler
18- from transformers .data import default_data_collator
19+ from transformers .data import DataCollatorWithPadding
1920
2021from llmcompressor .args import DatasetArguments
2122from llmcompressor .transformers .data import TextGenerationDataset
@@ -115,44 +116,56 @@ def get_calibration_dataloader(
115116 )
116117
117118 calibration_dataset = datasets .get ("calibration" )
119+ tokenizer = getattr (processor , "tokenizer" , processor )
120+ collate_fn = dataset_args .data_collator or DataCollatorWithPadding (tokenizer )
121+ if dataset_args .batch_size > 1 and (
122+ tokenizer .pad_token is None or tokenizer .pad_token_id < 0
123+ ):
124+ logger .warning ("Could not find padding token. Setting PAD token to EOS token" )
125+ tokenizer .pad_token = tokenizer .eos_token
118126
119127 return format_calibration_data (
120128 tokenized_dataset = calibration_dataset ,
129+ collate_fn = collate_fn ,
130+ batch_size = dataset_args .batch_size ,
121131 num_calibration_samples = dataset_args .num_calibration_samples ,
122132 do_shuffle = dataset_args .shuffle_calibration_samples ,
123- collate_fn = dataset_args .data_collator ,
124133 )
125134
126135
127136def format_calibration_data (
128137 tokenized_dataset : Dataset ,
138+ collate_fn : Callable ,
139+ batch_size : int = 1 ,
129140 num_calibration_samples : int | None = None ,
130141 do_shuffle : bool = True ,
131- collate_fn : Callable = default_data_collator ,
132142) -> list [torch .Tensor ]:
133143 """
134144 Creates a dataloader out of the calibration dataset split, trimming it to
135145 the desired number of calibration samples
136146 :param tokenized_dataset: dataset to convert to dataloader
137- :param num_calibration_samples: number of data samples to convert
147+ :param num_calibration_samples: number of batches to convert
138148 :param do_shuffle: whether to shuffle the dataset before selecting calibration
139149 samples, true by default
140150 :param collate_fn: optional custom collate function, or use default
141151 :return: list of trimmed calibration data tensors
142152 """
143- safe_calibration_samples = len (tokenized_dataset )
153+ # (1) shuffle dataset
154+ if do_shuffle :
155+ tokenized_dataset = tokenized_dataset .shuffle ()
156+
157+ # (2) truncate dataset
144158 if num_calibration_samples is not None :
145- safe_calibration_samples = min ( len ( tokenized_dataset ), num_calibration_samples )
146- if safe_calibration_samples != num_calibration_samples :
159+ num_batches = math . ceil ( num_calibration_samples / batch_size )
160+ if num_batches > len ( tokenized_dataset ) :
147161 logger .warning (
148- f"Requested { num_calibration_samples } calibration samples but "
149- f"the provided dataset only has { safe_calibration_samples } . "
162+ f"Requested { num_calibration_samples } calibration samples but the "
163+ f"provided dataset only has { len ( tokenized_dataset ) * batch_size } . "
150164 )
165+ num_batches = len (tokenized_dataset )
166+ tokenized_calibration = tokenized_dataset .select (num_batches )
151167
152- if do_shuffle :
153- tokenized_dataset = tokenized_dataset .shuffle ()
154- tokenized_calibration = tokenized_dataset .select (range (safe_calibration_samples ))
155-
168+ # (3) infer number of workers
156169 MAX_DATALOADER_WORKERS = 8
157170 try :
158171 num_workers = min (MAX_DATALOADER_WORKERS , multiprocessing .cpu_count () // 2 )
@@ -161,19 +174,18 @@ def format_calibration_data(
161174 "Could not determine number of CPUs, defaulting to 0 dataloader workers."
162175 )
163176 num_workers = 0
177+
178+ # (4) create dataloader
164179 dataloader_params = {
165- "batch_size" : 1 ,
180+ "batch_size" : batch_size ,
166181 "sampler" : RandomSampler (tokenized_calibration )
167182 if do_shuffle
168183 else SequentialSampler (tokenized_calibration ),
169184 "collate_fn" : collate_fn ,
170185 "pin_memory" : True ,
171186 "num_workers" : num_workers ,
172187 }
173-
174- calibration_dataloader = DataLoader (tokenized_calibration , ** dataloader_params )
175-
176- return calibration_dataloader
188+ return DataLoader (tokenized_calibration , ** dataloader_params )
177189
178190
179191def make_dataset_splits (
0 commit comments