
Commit f1def17

Authored by WeiweiZhang1, pre-commit-ci[bot], and wenhuach21
add export funcs for autoround and fix incorrect hyper-parameters in example (#1536)
* add export funcs for autoround
  Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixtypo
  Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixed incorrect default hyperparameters
* fixed mixstral7b*8 issue
* port export func to file
  Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add export
  Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixtypo
  Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* change device type
  Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: wenhuach21 <108330088+wenhuach21@users.noreply.github.com>
1 parent f141aff commit f1def17

File tree

7 files changed: +616, -21 lines


examples/pytorch/nlp/huggingface_models/language-modeling/quantization/autoround/main.py

Lines changed: 23 additions & 11 deletions
@@ -1,21 +1,21 @@
 import argparse
-
-
-from neural_compressor.adaptor.torch_utils.autoround import AutoRound, AutoOPTRound, AutoAdamRound
+import sys
+from neural_compressor.adaptor.torch_utils.autoround import (AutoRound,
+                                                             AutoOPTRound,
+                                                             AutoAdamRound)
 
 parser = argparse.ArgumentParser()
 import torch
 import os
+import re
+import json
 
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 torch.use_deterministic_algorithms(True, warn_only=True)
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
-
 from transformers import set_seed
-
 from eval import eval_model
 
-import re
-
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -44,13 +44,13 @@
 parser.add_argument("--sym", action='store_true',
                     help=" sym quantization")
 
-parser.add_argument("--iters", default=400, type=int,
+parser.add_argument("--iters", default=200, type=int,
                     help=" iters")
 
 parser.add_argument("--use_quant_input", action='store_true',
                     help="whether to use the output of quantized block to tune the next block")
 
-parser.add_argument("--lr", default=0.05, type=float,
+parser.add_argument("--lr", default=0.005, type=float,
                     help="step size")
 
 parser.add_argument("--minmax_lr", default=None, type=float,
@@ -83,6 +83,9 @@
 
 parser.add_argument("--enable_minmax_tuning", action='store_true',
                     help="whether enable weight minmax tuning")
+
+parser.add_argument("--use_optimum_format", default=True,
+                    help="whether use HuggingFace format.")
 
 # parser.add_argument("--tasks", default=["lambada_openai", "hellaswag", "winogrande", "piqa"],
 #                     help="lm-eval tasks")
@@ -186,9 +189,17 @@
 
 optq = round(model, tokenizer, args.num_bits, args.group_size, scheme, bs=args.train_bs,
              seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,
-             minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input,
-             amp=args.amp, n_samples=args.n_samples, low_gpu_mem_usage=args.low_gpu_mem_usage, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps)  ##TODO args pass
-optq.quantize()
+             use_quant_input=args.use_quant_input, amp=args.amp, n_samples=args.n_samples,
+             low_gpu_mem_usage=args.low_gpu_mem_usage, minmax_lr=args.minmax_lr,
+             seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps)  ##TODO args pass
+q_model, q_config = optq.quantize()
+if args.use_optimum_format:
+    output_dir = args.output_dir + "_" + args.model_name.split('/')[-1] + "/"
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    q_config_path = os.path.join(output_dir, "qconfig.json")
+    with open(q_config_path, "w") as f:
+        json.dump(q_config, f, indent=4)
 
 torch.cuda.empty_cache()
 model.eval()
@@ -202,3 +213,4 @@
 eval_model(output_dir=output_dir, model=model, tokenizer=tokenizer, tasks=args.tasks, \
            eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage, device=cuda_device, excel_file=excel_name,
            limit=None)
+
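The example stops after dumping qconfig.json; packing the tuned model is left to the new export helper. A minimal follow-up sketch, assuming the q_model, q_config, and output_dir produced above (the save call presumes a standard transformers model and is illustrative, not part of this commit):

# Hedged sketch: pack the tuned model using the config the example just dumped.
# `q_model` and `output_dir` are the objects/paths from the example run above.
from neural_compressor.adaptor.torch_utils.autoround import export_compressed_model

compressed_model = export_compressed_model(
    q_model,                                   # tuned model returned by optq.quantize()
    os.path.join(output_dir, "qconfig.json"),  # the q_config dict also works here
    device="cpu",
)
compressed_model.save_pretrained(output_dir)   # assumes a transformers PreTrainedModel; illustrative only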

neural_compressor/adaptor/torch_utils/autoround/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .autoround import AutoRound, AutoOPTRound, AutoAdamRound
+from .export import export_compressed_model
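With this re-export in place, the new helper is importable from the package namespace alongside the tuners, e.g.:

from neural_compressor.adaptor.torch_utils.autoround import AutoRound, export_compressed_model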

neural_compressor/adaptor/torch_utils/autoround/autoround.py

Lines changed: 57 additions & 6 deletions
@@ -142,11 +142,13 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", v=0, min_scal
         weight = weight.reshape(-1, group_size)
         if isinstance(v, torch.Tensor):
             v = v.reshape(-1, group_size)
-
         weight, scale, zp = quant_weight_actor(
             weight, num_bits, scheme=scheme, v=v, min_scale=min_scale, max_scale=max_scale
         )
         weight = weight.reshape(orig_shape)
+        scale = scale.reshape(orig_shape[0], -1)  # TODO validating the feasibility on conv1d
+        if zp is not None:
+            zp = zp.reshape(orig_shape[0], -1)
         return weight, scale, zp
 
     else:
@@ -160,11 +162,49 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", v=0, min_scal
             weight_new, num_bits, scheme=scheme, v=v, min_scale=min_scale, max_scale=max_scale
         )
         weight_new = weight_new.reshape(orig_shape[0], -1)
-
+        scale = scale.reshape(orig_shape[0], -1)
+        if zp is not None:
+            zp = zp.reshape(orig_shape[0], -1)
         weight_new = weight_new[:, :-pad_len]
+        scale = scale[:, :-pad_len]
+        zp = zp[:, :-pad_len]
         return weight_new, scale, zp
 
 
+def quant_weight_w_scale(weight, scale, zp, group_size=-1):
+    """Quant and dequant tensor with group size.
+
+    Args:
+        weight: input weight
+        scale: scale
+        zp: zero point
+        group_size (int, optional): how many elements share one scale/zp. Defaults to -1.
+
+    Returns:
+        output: int weight.
+    """
+    device = weight.device
+    scale = scale.to(device)
+    if zp is not None:
+        zp = zp.to(device)
+    if group_size == -1:
+        return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
+    int_weight = torch.zeros(weight.shape).to(device)
+    leng = weight.shape[1] // group_size
+    tail_flag = False if weight.shape[1] % group_size == 0 else True
+    for i in range(leng):
+        int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
+        if zp is not None:
+            int_weight_tmp += zp[:, i].unsqueeze(1)
+        int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
+    if tail_flag:
+        int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
+        if zp is not None:
+            int_weight_tmp += zp[:, -1].unsqueeze(1)
+        int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
+    return int_weight
+
+
 def round_ste(x: torch.Tensor):
     """Straight-Through Estimator for rounding.
     This function is adapted from omniquant.
@@ -819,6 +859,7 @@ def get_block_names(model):
     for n, m in model.named_modules():
         if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
             target_m = (n, m)
+            break
     for n, m in target_m[1].named_children():
         block_names.append(target_m[0] + "." + n)
     return block_names
@@ -976,7 +1017,6 @@ def __init__(
         self.amp = amp
         self.use_quant_input = use_quant_input
         self.enable_minmax_tuning = enable_minmax_tuning
-        self.n_samples = n_samples
         self.n_blocks = n_blocks
         self.bits = bits
         self.group_size = group_size
@@ -997,6 +1037,8 @@ def __init__(
         self.tokenizer = tokenizer
         self.seqlen = seqlen
         self.train_bs = bs
+        self.n_samples = bs * (n_samples // bs)
+        assert self.n_samples > 0, f"Recommend setting an n_samples that is divisible by batch size{self.train_bs}"
         self.n_blocks = n_blocks
         self.device = device
         self.amp_dtype = torch.float16
@@ -1393,20 +1435,29 @@ def quantize(self):
             if n in self.weight_config.keys():
                 if hasattr(m, "scale"):
                     self.weight_config[n]["scale"] = m.scale
+                    # self.weight_config[n]["scale_dtype"] = m.scale.dtype
                     self.weight_config[n]["zp"] = m.zp
+                    # self.weight_config[n]["zp_dtype"] = m.zp.dtype
                     delattr(m, "scale")
                     delattr(m, "zp")
                 else:
                     self.weight_config[n]["data_type"] = "float"
-                    if self.amp_dtype == torch.bfloat16:
-                        self.weight_config[n]["data_type"] = "bfloat"
-                    self.weight_config[n]["bits"] = 16
+                    self.weight_config[n]["bits"] = 32
+                    if self.amp:
+                        self.weight_config[n]["bits"] = 16
+                        if self.amp_dtype == torch.bfloat16:
+                            self.weight_config[n]["data_type"] = "bfloat"
                     self.weight_config[n]["group_size"] = None
                     self.weight_config[n]["sym"] = None
 
+        for k, v in self.weight_config.items():
+            for m, n in v.items():
+                if isinstance(n, torch.Tensor):
+                    self.weight_config[k][m] = n.tolist()
         end_time = time.time()
         cost_time = end_time - start_time
         logger.info(f"quantization runtime {cost_time}")
+
         return self.model, self.weight_config
 
 
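The new quant_weight_w_scale helper turns a floating-point weight back into integer codes from previously computed per-group scales and zero points; the export helper in the next file uses it when packing layers. A small self-contained sketch with made-up shapes and values (not from the commit):

# Hypothetical example: an 8x8 fp32 weight quantized with group_size=4, so each row
# carries one scale/zero-point per group of 4 input channels (shape [8, 2]).
import torch
from neural_compressor.adaptor.torch_utils.autoround.autoround import quant_weight_w_scale

weight = torch.randn(8, 8)
scale = torch.full((8, 2), 0.1)
zp = torch.full((8, 2), 8)

int_weight = quant_weight_w_scale(weight, scale, zp, group_size=4)
# Codes are rounded to integers but stored as float32; export.py casts them to int32 later.
print(int_weight.shape)  # torch.Size([8, 8])
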
neural_compressor/adaptor/torch_utils/autoround/export.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import json
+from typing import Union
+
+try:
+    from neural_compressor.utils.utility import LazyImport
+
+    torch = LazyImport("torch")
+    from neural_compressor.utils import logger
+except:  # pragma: no cover
+    import logging
+
+    import torch
+
+    logger = logging.getLogger()
+
+
+def export_compressed_model(
+    model,
+    weight_config: Union[str, dict],
+    enable_full_range=False,
+    compression_dtype=torch.int32,
+    compression_dim=1,
+    scale_dtype=torch.float32,
+    device="cpu",
+    use_optimum_format=True,
+):
+    """Convert Linear to WeightOnlyLinear for low memory inference.
+
+    Args:
+        weight_config (str|dict): qconfig dict or Path of qconfig.json.
+        enable_full_range (bool, optional): Whether to leverage the full compression range
+                                            under symmetric quantization. Defaults to False.
+        compression_dtype (torch.Tensor, optional): The target dtype after comoression.
+                                                    Defaults to torch.int32.
+        compression_dim (int, optional): Select from [0, 1], 0 is output channel,
+                                         1 is input channel. Defaults to 1.
+        scale_dtype (torch.Tensor, optional): Use float32 or float16.
+                                              Defaults to torch.float32.
+        device (str, optional): choose device for compression. Defaults to cpu.
+        use_optimum_format (bool, optional): use the popular huggingface compression format.
+            1: compression_dim: weight = 1, zeros = 0 and both are transposed.
+            2: zeros -= 1 before compression. Why we need it?
+            3: g_idx: use same number for one group instead of recording the channel order.
+            4. parameter name changed, such as 'packed_weight' -> 'qweight'.
+            5. zeros is always needed even for sym.
+    """
+    from .autoround import get_module, quant_weight_w_scale, set_module
+    from .model_wrapper import WeightOnlyLinear
+
+    compressed_model = copy.deepcopy(model)
+    if isinstance(weight_config, str):
+        with open(weight_config, "r") as f:
+            q_config = json.load(f)
+    else:
+        q_config = weight_config
+    for k, v in q_config.items():
+        logger.info(f"Compressing {k} on device {device}")
+        if v["data_type"] == "float":
+            continue
+        else:
+            dtype = v["data_type"]
+            num_bits = v["bits"]
+            group_size = v["group_size"]
+            scheme = v["scheme"]
+            m = get_module(compressed_model, k)
+            fp_weight = m.weight.data
+            scale = torch.tensor(v["scale"], dtype=torch.float32)  # may exist dtype dismatch problem
+            zp = None if scheme == "sym" else torch.tensor(v["zp"], dtype=torch.int32)
+            int_weight = quant_weight_w_scale(fp_weight, scale, zp, group_size)
+            int_weight = int_weight.type(torch.int32)
+            new_module = WeightOnlyLinear(
+                m.in_features,
+                m.out_features,
+                num_bits,
+                group_size,
+                dtype=dtype,
+                zp=zp is not None,
+                bias=m.bias is not None,
+                device=device,
+                use_optimum_format=True,
+            )
+            new_module.pack(int_weight, scale, zp, m.bias)
+            set_module(compressed_model, k, new_module)
+    return compressed_model
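
A hedged usage sketch for export_compressed_model, assuming q_model and q_config come from a prior AutoRound quantize() run; the module inspection at the end is illustrative only:

# weight_config may be the q_config dict or the path of a dumped qconfig.json.
from neural_compressor.adaptor.torch_utils.autoround import export_compressed_model

compressed = export_compressed_model(q_model, q_config, device="cpu", use_optimum_format=True)

# Layers whose entry is not "float" are swapped for WeightOnlyLinear modules holding
# packed integer weights plus per-group scales/zero points; the rest stay untouched.
packed = [n for n, m in compressed.named_modules() if type(m).__name__ == "WeightOnlyLinear"]
print(f"packed {len(packed)} linear layers, e.g. {packed[:3]}")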
