Commit 35a0507

add batch size args

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 34814c7

File tree

6 files changed: +90 −47 lines changed

examples/multimodal_vision/gemma3_example.py
examples/multimodal_vision/idefics3_example.py
examples/quantization_w4a16/llama3_example.py
src/llmcompressor/args/dataset_arguments.py
src/llmcompressor/datasets/utils.py
src/llmcompressor/entrypoints/oneshot.py

examples/multimodal_vision/gemma3_example.py

Lines changed: 16 additions & 9 deletions
@@ -1,7 +1,10 @@
 import requests
-import torch
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import (
+    AutoProcessor,
+    DataCollatorWithPadding,
+    Gemma3ForConditionalGeneration,
+)
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -11,18 +14,21 @@
 model_id = "google/gemma-3-4b-it"
 model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+collator = DataCollatorWithPadding(processor.tokenizer)
 
 # Oneshot arguments
-DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
+BATCH_SIZE = 512
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+# Define a oneshot data collator for multimodal processors;
+# remove the extra dim added by the vision processor.
+def data_collator(features: list[dict[str, object]]):
+    features = [{key: feature[key][0] for key in feature} for feature in features]
+    return collator(features)
 
 
 # Recipe
@@ -45,10 +51,11 @@ def data_collator(batch):
     dataset=DATASET_ID,
     splits=DATASET_SPLIT,
     recipe=recipe,
+    batch_size=BATCH_SIZE,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    trust_remote_code_model=True,
     data_collator=data_collator,
+    trust_remote_code_model=True,
 )
 
 # Confirm generations of the quantized model look sane.
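
Why the collator changed: the old version asserted a batch size of 1 and tensorized a single sample, which cannot work once batch_size > 1. Vision processors return each tensor with a leading singleton batch dim, so the new wrapper strips that dim and defers padding to DataCollatorWithPadding. A toy illustration of the same pattern; the tokenizer name is an assumption for demonstration, not part of this commit:

import torch
from transformers import AutoTokenizer, DataCollatorWithPadding

# any tokenizer with a pad token works for this demo
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorWithPadding(tokenizer)

# vision processors typically emit (1, seq_len) tensors per sample
features = [
    {"input_ids": torch.tensor([[101, 2023, 102]])},
    {"input_ids": torch.tensor([[101, 2023, 2003, 1037, 102]])},
]

# strip the leading singleton dim, then let the collator pad to max length
features = [{key: feature[key][0] for key in feature} for feature in features]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 5])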

examples/multimodal_vision/idefics3_example.py

Lines changed: 16 additions & 10 deletions
@@ -1,8 +1,11 @@
 import requests
-import torch
 from datasets import load_dataset
 from PIL import Image
-from transformers import AutoProcessor, Idefics3ForConditionalGeneration
+from transformers import (
+    AutoProcessor,
+    DataCollatorWithPadding,
+    Idefics3ForConditionalGeneration,
+)
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -12,18 +15,21 @@
 model_id = "HuggingFaceM4/Idefics3-8B-Llama3"  # or "HuggingFaceTB/SmolVLM-Instruct"
 model = Idefics3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+collator = DataCollatorWithPadding(processor.tokenizer)
 
 # Oneshot arguments
-DATASET_ID = "lmms-lab/flickr30k"
-DATASET_SPLIT = "test[:512]"
 NUM_CALIBRATION_SAMPLES = 512
-MAX_SEQUENCE_LENGTH = 4096  # Seems to be required here
+MAX_SEQUENCE_LENGTH = 4096
+BATCH_SIZE = 512
+DATASET_ID = "lmms-lab/flickr30k"
+DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]"
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+# Define a oneshot data collator for multimodal processors;
+# remove the extra dim added by the vision processor.
+def data_collator(features: list[dict[str, object]]):
+    features = [{key: feature[key][0] for key in feature} for feature in features]
+    return collator(features)
 
 
 # Recipe
@@ -86,8 +92,8 @@ def tokenize(sample):
     model=model,
     dataset=ds,
     recipe=recipe,
+    batch_size=BATCH_SIZE,
     max_seq_length=MAX_SEQUENCE_LENGTH,
-    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
     data_collator=data_collator,
     sequential_targets=["LlamaDecoderLayer"],

examples/quantization_w4a16/llama3_example.py

Lines changed: 5 additions & 3 deletions
@@ -6,7 +6,8 @@
 from llmcompressor.utils import dispatch_for_generation
 
 # Select model and load it.
-model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+model_id = "meta-llama/Llama-3.2-1B-Instruct"
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
@@ -16,8 +17,9 @@
 
 # Select number of samples. 512 samples is a good place to start.
 # Increasing the number of samples can improve accuracy.
-NUM_CALIBRATION_SAMPLES = 512
+NUM_CALIBRATION_SAMPLES = 16
 MAX_SEQUENCE_LENGTH = 2048
+BATCH_SIZE = 4
 
 # Load dataset and preprocess.
 ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
@@ -58,8 +60,8 @@ def tokenize(sample):
     model=model,
     dataset=ds,
     recipe=recipe,
+    batch_size=BATCH_SIZE,
     max_seq_length=MAX_SEQUENCE_LENGTH,
-    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
 # Confirm generations of the quantized model look sane.
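
The example now calibrates with 16 samples in batches of 4, i.e. four dataloader batches. A quick sanity check of that arithmetic with a plain PyTorch DataLoader and pad-to-longest collation; this is a standalone sketch, not code from this commit:

import torch
from torch.utils.data import DataLoader

samples = [{"input_ids": torch.arange(n)} for n in range(4, 20)]  # 16 samples

def pad_collate(features):
    # pad variable-length sequences to the longest one in the batch
    ids = [f["input_ids"] for f in features]
    return torch.nn.utils.rnn.pad_sequence(ids, batch_first=True)

loader = DataLoader(samples, batch_size=4, collate_fn=pad_collate)
print(len(loader))  # 4 batches of 4 samples each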

src/llmcompressor/args/dataset_arguments.py

Lines changed: 20 additions & 6 deletions
@@ -8,9 +8,7 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Any, Callable
-
-from transformers import DefaultDataCollator
+from typing import Callable, Optional
 
 
 @dataclass
@@ -69,9 +67,25 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    data_collator: Optional[Callable] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The function used to form a batch from the dataset. Defaults to "
+                "`DataCollatorWithPadding(processor)`."
+            )
+        },
+    )
+
+    batch_size: int = field(
+        default=1,
+        metadata={
+            "help": (
+                "Calibration batch size. During calibration, LLM Compressor disables "
+                "lm_head output computations to reduce memory usage from large "
+                "calibration batches."
+            )
+        },
     )
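
Since DatasetArguments is a dataclass with field metadata, the new batch_size option surfaces wherever the arguments are parsed. A minimal sketch of that pattern with transformers' HfArgumentParser; the toy class below is illustrative, not from this commit:

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class ToyDatasetArguments:
    # mirrors the shape of the new field above
    batch_size: int = field(default=1, metadata={"help": "Calibration batch size"})

parser = HfArgumentParser(ToyDatasetArguments)
(args,) = parser.parse_args_into_dataclasses(["--batch_size", "8"])
print(args.batch_size)  # 8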

src/llmcompressor/datasets/utils.py

Lines changed: 30 additions & 18 deletions
@@ -7,6 +7,7 @@
 one-shot calibration workflows.
 """
 
+import math
 import multiprocessing
 import re
 from typing import Any, Callable
@@ -15,7 +16,7 @@
 from datasets import Dataset
 from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data import DataCollatorWithPadding
 
 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.data import TextGenerationDataset
@@ -115,44 +116,56 @@ def get_calibration_dataloader(
     )
 
     calibration_dataset = datasets.get("calibration")
+    tokenizer = getattr(processor, "tokenizer", processor)
+    collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer)
+    if dataset_args.batch_size > 1 and (
+        tokenizer.pad_token is None or tokenizer.pad_token_id < 0
+    ):
+        logger.warning("Could not find padding token. Setting PAD token to EOS token")
+        tokenizer.pad_token = tokenizer.eos_token
 
     return format_calibration_data(
         tokenized_dataset=calibration_dataset,
+        collate_fn=collate_fn,
+        batch_size=dataset_args.batch_size,
         num_calibration_samples=dataset_args.num_calibration_samples,
         do_shuffle=dataset_args.shuffle_calibration_samples,
-        collate_fn=dataset_args.data_collator,
     )
 
 
 def format_calibration_data(
     tokenized_dataset: Dataset,
+    collate_fn: Callable,
+    batch_size: int = 1,
     num_calibration_samples: int | None = None,
     do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
 ) -> list[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
     the desired number of calibration samples
     :param tokenized_dataset: dataset to convert to dataloader
-    :param num_calibration_samples: number of data samples to convert
+    :param num_calibration_samples: number of batches to convert
     :param do_shuffle: whether to shuffle the dataset before selecting calibration
         samples, true by default
     :param collate_fn: optional custom collate function, or use default
     :return: list of trimmed calibration data tensors
     """
-    safe_calibration_samples = len(tokenized_dataset)
+    # (1) shuffle dataset
+    if do_shuffle:
+        tokenized_dataset = tokenized_dataset.shuffle()
+
+    # (2) truncate dataset
     if num_calibration_samples is not None:
-        safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
-        if safe_calibration_samples != num_calibration_samples:
+        num_batches = math.ceil(num_calibration_samples / batch_size)
+        if num_batches > len(tokenized_dataset):
             logger.warning(
-                f"Requested {num_calibration_samples} calibration samples but "
-                f"the provided dataset only has {safe_calibration_samples}. "
+                f"Requested {num_calibration_samples} calibration samples but the "
+                f"provided dataset only has {len(tokenized_dataset) * batch_size}. "
             )
+            num_batches = len(tokenized_dataset)
+    tokenized_calibration = tokenized_dataset.select(range(num_batches))
 
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
-    tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
-
+    # (3) infer number of workers
     MAX_DATALOADER_WORKERS = 8
     try:
         num_workers = min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2)
@@ -161,19 +174,18 @@ def format_calibration_data(
             "Could not determine number of CPUs, defaulting to 0 dataloader workers."
         )
         num_workers = 0
+
+    # (4) create dataloader
     dataloader_params = {
-        "batch_size": 1,
+        "batch_size": batch_size,
         "sampler": RandomSampler(tokenized_calibration)
         if do_shuffle
         else SequentialSampler(tokenized_calibration),
         "collate_fn": collate_fn,
         "pin_memory": True,
         "num_workers": num_workers,
     }
-
-    calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
-
-    return calibration_dataloader
+    return DataLoader(tokenized_calibration, **dataloader_params)
 
 
 def make_dataset_splits(
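
Two behaviors added here are easy to miss: batching requires a pad token, so get_calibration_dataloader falls back to EOS when the tokenizer has none, and the sample budget is converted to a batch count by ceiling division. A small standalone illustration; GPT-2 is an assumption, chosen because its tokenizer ships without a pad token:

import math

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    # mirror the fallback in the patch above
    tokenizer.pad_token = tokenizer.eos_token

num_calibration_samples, batch_size = 512, 64
num_batches = math.ceil(num_calibration_samples / batch_size)
print(tokenizer.pad_token == tokenizer.eos_token, num_batches)  # True 8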

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,7 @@
 import os
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Optional
 
 from loguru import logger
 from torch.utils.data import DataLoader
@@ -248,6 +248,8 @@ def oneshot(
     dataset_config_name: str | None = None,
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
+    batch_size: int = 1,
+    data_collator: Optional[Callable] = None,
     num_calibration_samples: int = 512,
     shuffle_calibration_samples: bool = True,
     max_seq_length: int = 384,
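
With these two keyword arguments in place, callers can opt into batched calibration without writing a collator. A hedged usage sketch, assuming a build of llmcompressor that includes this commit; the dataset alias and recipe settings are illustrative, not prescribed by the diff:

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="meta-llama/Llama-3.2-1B-Instruct",
    dataset="open_platypus",
    recipe=recipe,
    batch_size=4,        # new: number of samples per calibration batch
    data_collator=None,  # new: None falls back to DataCollatorWithPadding
    num_calibration_samples=512,
)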
