Commit 63c175b
Add Intel AutoRound algorithm support (#1994)
Resolves #1968

### Highlights

- Introduced `AutoRoundModifier` to enable AutoRound quantization for `wNa16`.
- Added an end-to-end example and unit tests.
- Verified functionality with local accuracy tests (GSM8K with a limit of 1000; results may fluctuate due to non-determinism).

**LLMC-AutoRound**

`vllm (pretrained=/storage/yiliu7/Meta-Llama-3-8B-Instruct-W4A16-G128-disbale-shuffule,tensor_parallel_size=1,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False)`, gen_kwargs: (None), limit: 1000.0, num_fewshot: None, batch_size: 128

| Tasks | Version | Filter           | n-shot | Metric      |   | Value |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.737 | ± | 0.0139 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.736 | ± | 0.0139 |

**AutoRound result (reference)**

`vllm (pretrained=/storage/yiliu7/meta-llama/Meta-Llama-3-8B-Instruct-ar/Meta-Llama-3-8B-Instruct-w4g128/,tensor_parallel_size=1,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False)`, gen_kwargs: (None), limit: 1000.0, num_fewshot: None, batch_size: 128

| Tasks | Version | Filter           | n-shot | Metric      |   | Value |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.739 | ± | 0.0139 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.740 | ± | 0.0139 |

The [eval cmd](https://gist.github.com/yiliu30/a7881cd1cbf0d676e3ffac3e3833aa8e) is attached FYI.

### Next stage (in later PRs)

- [ ] Extend support for additional data types.
- [ ] Add group-wise quantization recipe mappings between LLMC and AutoRound.
- [ ] Add end-to-end tests.

cc @hshen14 @thuang6 @wenhuach21

---------

Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: Yi Liu <yi4.liu@intel.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
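For reproduction, the linked gist amounts to an lm-evaluation-harness run over vLLM. Below is a minimal Python sketch using `lm_eval.simple_evaluate`; the checkpoint path is a placeholder for a locally saved model, and only a subset of the engine arguments reported above is shown.

```python
# Minimal sketch of the GSM8K evaluation reported above, via the
# lm-evaluation-harness Python API. The pretrained path is a placeholder;
# the full set of engine arguments is in the linked gist.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound,"
        "add_bos_token=True,dtype=bfloat16,gpu_memory_utilization=0.8"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=1000,
    batch_size=128,
)
print(results["results"]["gsm8k"])
```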
1 parent c600e2e

File tree: 8 files changed (+508, -3 lines)

.github/workflows/test-check-transformers.yaml

Lines changed: 4 additions & 0 deletions
```diff
@@ -97,6 +97,10 @@ jobs:
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
           pytest -v tests/llmcompressor/transformers/gptq
+      - name: Running AutoRound Tests
+        if: (success() || failure()) && steps.install.outcome == 'success'
+        run: |
+          pytest -v tests/llmcompressor/transformers/autoround
       - name: Running ONESHOT Tests
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
```
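The new tests under `tests/llmcompressor/transformers/autoround` are not included in this capture of the diff. A hypothetical smoke test of the kind that might live there (test name and assertions are assumptions, based only on the constructor arguments used in the example script below) could look like:

```python
# Hypothetical smoke test; the real tests added by this PR are not shown
# here. Assumes AutoRoundModifier exposes its constructor arguments as
# attributes, as llm-compressor modifiers (pydantic models) generally do.
from llmcompressor.modifiers.autoround import AutoRoundModifier


def test_autoround_modifier_construction():
    modifier = AutoRoundModifier(
        targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
    )
    assert modifier.iters == 200
    assert "lm_head" in modifier.ignore
```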
Lines changed: 56 additions & 0 deletions
```python
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset parameters.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048

# Get the aligned calibration dataset.
ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)

# Configure the quantization algorithm to run:
# * quantize the weights to 4 bits with AutoRound, using a group size of 128
recipe = AutoRoundModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # disable shuffling to get a slightly better MMLU score
    shuffle_calibration_samples=False,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save the compressed model to disk.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

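As a follow-up to the example above, the saved checkpoint can be served with vLLM, which is how the GSM8K numbers in the description were obtained. A minimal sketch (the model path matches `SAVE_DIR` from the example; the engine arguments here are illustrative, not the exact eval configuration):

```python
# Minimal sketch: load the compressed checkpoint saved by the example above
# with vLLM and generate a short completion. The model path matches SAVE_DIR
# from the example; engine arguments are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(model="Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound", dtype="bfloat16")
params = SamplingParams(max_tokens=64)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)
```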
setup.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -144,6 +144,8 @@ def localversion_func(version: ScmVersion) -> str:
             if BUILD_TYPE == "release"
             else "compressed-tensors>=0.12.3a2"
         ),
+        # TODO: replace it with the release version
+        ("auto_round @ git+https://github.com/intel/auto-round.git@llmc"),
     ],
     extras_require={
         "dev": [
```
Lines changed: 3 additions & 0 deletions
```python
# ruff: noqa

from .base import *
```
