
Commit 027ccf3

[feat] support mobilellm.
1 parent 35c3804 commit 027ccf3

1 file changed: +76 additions, -97 deletions

llmexport/llmexport.py

Lines changed: 76 additions & 97 deletions
@@ -94,6 +94,7 @@ def regist_llama(self):
         self.regist('llama', llama_map)
         self.regist('qwen2', llama_map)
         self.regist('internlm', llama_map)
+        self.regist('mobilellm', llama_map)
         # baichuan
         baichuan_map = copy.deepcopy(self.default_map)
         baichuan_map[self.attention_key] = {
@@ -365,9 +366,7 @@ def __init__(
         self.tokenizer = model.tokenizer
         self.w_bit = model.quant_bit
         self.group_size = model.quant_block
-        self.zeropoint = False
-        # self.calib_data = model.calib_data
-        # self.split = model.split
+        self.zeropoint = not model.symmetric
         self.calib_data = 'ag_news'
         self.split = 'test'
         self.duo_scaling = True
@@ -390,42 +389,20 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
         w = w.reshape(-1, self.group_size)
         assert w.dim() == 2
         assert torch.isnan(w).sum() == 0
-
         # zero point quantization
         if self.zeropoint:
-            '''
-            max_val = w.amax(dim=1, keepdim=True)
-            min_val = w.amin(dim=1, keepdim=True)
-            max_int = 2**self.w_bit - 1
-            min_int = 0
-            scales = (max_val - min_val).clamp(min=1e-5) / max_int
-            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-            w = (
-                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
-            ) * scales
-            zeros = zeros.view(org_w_shape[0], -1)
-            '''
-            #'''
             max_val = w.amax(dim=1, keepdim=True)
             min_val = w.amin(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
             clip_min = -offset
             scales = (max_val - min_val) / (clip_max - clip_min)
+            zeros = - torch.round(min_val / scales) + clip_min
+            qw = torch.round(w / scales) + zeros
+            qw = torch.clamp(qw, clip_min, clip_max)
+            w = (qw - zeros) * scales
             zeros = min_val.view(org_w_shape[0], -1)
-            qw = torch.clamp(torch.round((w - min_val) / scales) + clip_min, clip_min, clip_max)
-            w = (qw - clip_min) * scales + min_val
-            #'''
         else:
-            '''
-            max_val = w.abs().amax(dim=1, keepdim=True)
-            max_val = max_val.clamp(min=1e-5)
-            max_int = 2 ** (self.w_bit - 1) - 1
-            min_int = -(2 ** (self.w_bit - 1))
-            scales = max_val / max_int
-            zeros = None
-            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
-            '''
             abs_max = w.abs().amax(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
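
For reference, a minimal standalone sketch of the asymmetric (zero-point) round-trip that the rewritten branch above performs; the function name and defaults are illustrative, and it assumes the flattened weight length is divisible by group_size:

import torch

def fake_quant_asym(w: torch.Tensor, w_bit: int = 4, group_size: int = 128) -> torch.Tensor:
    # Quantize then dequantize one weight tensor, mirroring the zero-point branch above.
    # Like that branch, it assumes each group has max_val > min_val.
    org_shape = w.shape
    w = w.reshape(-1, group_size)
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    offset = 1 << (w_bit - 1)
    clip_max, clip_min = offset - 1, -offset
    scales = (max_val - min_val) / (clip_max - clip_min)
    zeros = -torch.round(min_val / scales) + clip_min
    qw = torch.clamp(torch.round(w / scales) + zeros, clip_min, clip_max)
    return ((qw - zeros) * scales).reshape(org_shape)
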
@@ -464,7 +441,7 @@ def quantize(self):
             ].to(common_device)

             self.inps = self.inps.to(common_device)
-            print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')
+            # print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')

             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = AwqQuantizer.get_named_linears(self.modules[i])
@@ -522,7 +499,7 @@ def quantize(self):
                 self._search_best_scale(self.modules[i], **layer)
                 for layer in module_config
             ]
-            print(scales_list); exit(0)
+            # print(scales_list); exit(0)
             AwqQuantizer.apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             # [STEP 3]: Compute and apply clipping list
             if self.apply_clip:
@@ -1316,7 +1293,7 @@ def __init__(self, onnx_path, weight_ops, config):
         self.quant_block = config.quant_block
         self.quant_bit = config.quant_bit
         self.lm_quant_bit = config.lm_quant_bit
-        self.zeropoint = config.zeropoint
+        self.symmetric = config.symmetric
         self.mnn_weight_offset = 0
         self.onnx_model_path = onnx_path
         self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
@@ -1430,7 +1407,7 @@ def rebuild(self, json_path):
             json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
         return self.mnn_weight_path

-    def quant(self, weight, quant_bit, quant_block, zeropoint):
+    def quant(self, weight, quant_bit, quant_block, symmetric):
         weight = weight.numpy()
         oc, ic = weight.shape
         if quant_block == 0:
@@ -1441,27 +1418,32 @@ def quant(self, weight, quant_bit, quant_block, zeropoint):
         weight = weight.reshape(oc, block_num, block_size)
         offset = 1 << (quant_bit - 1)
         clip_max = offset - 1
-        if zeropoint:
-            clip_min = -offset
-            max_val = np.max(weight, axis=-1, keepdims=True)
-            min_val = np.min(weight, axis=-1, keepdims=True)
-            scale = (max_val - min_val) / (clip_max - clip_min)
-
-            # q_weight = np.round((weight - min_val) / scale) + clip_min
-            q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
-            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-            alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        else:
+        if symmetric:
             clip_min = -clip_max
             abs_max = np.max(np.abs(weight), axis=-1, keepdims=True)
             scale = abs_max / clip_max
             q_weight = np.round(weight / scale)
             q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
             alpha = scale.flatten()
+        else:
+            clip_min = -offset
+            max_val = np.max(weight, axis=-1, keepdims=True)
+            min_val = np.min(weight, axis=-1, keepdims=True)
+            scale = (max_val - min_val) / (clip_max - clip_min)
+
+            if False:
+                q_weight = np.round((weight - min_val) / scale) + clip_min
+                zeros = min_val - scale * clip_min
+            else:
+                q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
+                zeros = (np.round(min_val / scale) - clip_min) * scale
+            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
+            alpha = np.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
         q_weight = q_weight.reshape(-1, 2)
         if quant_bit == 4:
             q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]

+        clip_min = 1
         return q_weight, alpha, clip_min

     def write_npy(self, data):
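
To see exactly what the asymmetric branch stores, the sketch below is an illustrative inverse of quant(): it unpacks the two 4-bit codes per byte and applies the (zero, scale) pairs that alpha interleaves per block. The helper name is hypothetical and this is not MNN's runtime loader, just the mathematical inverse of the encoding above:

import numpy as np

def dequant_asym(q_packed: np.ndarray, alpha: np.ndarray, quant_bit: int = 4,
                 block_size: int = 128) -> np.ndarray:
    # Undo quant()'s asymmetric encoding: unpack codes, shift back to the signed
    # range, then apply w ~= (q - offset) * scale + zero per block.
    offset = 1 << (quant_bit - 1)
    if quant_bit == 4:
        # quant() packs two 4-bit codes per byte as q[:, 0] * 16 + q[:, 1]
        q = np.stack([q_packed // 16, q_packed % 16], axis=-1).flatten()
    else:
        q = q_packed.flatten()
    zeros, scale = alpha[0::2], alpha[1::2]   # alpha interleaves (zero, scale) per block
    # rows correspond to the blocks of weight.reshape(oc, block_num, block_size)
    q = q.astype(np.float32).reshape(len(scale), block_size)
    return (q - offset) * scale[:, None] + zeros[:, None]
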
@@ -1483,12 +1465,18 @@ def write_header(self, ic, oc, quant_bit):
         header_length = dim_num + dim_length + map_length
         return header_length, shape_dtype == np.int32

-    def build_weight(self, linear, quant_bit, quant_block, zeropoint):
+    def build_weight(self, linear, quant_bit, quant_block, symmetric):
         ic, oc = linear.in_features, linear.out_features
-        q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, zeropoint)
-        header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
-        weight_len = self.write_npy(q_weight) + header_len
-        alpha_len = self.write_npy(alpha)
+        if quant_bit == 16:
+            half_weight = linear.weight.data.half().flatten().numpy()
+            weight_len = self.write_npy(half_weight)
+            alpha_len, q_min, shape_int32 = 0, 0, False
+        else:
+            assert(quant_bit in (4, 8))
+            q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
+            header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
+            weight_len = self.write_npy(q_weight) + header_len
+            alpha_len = self.write_npy(alpha)
         if linear.bias is not None:
             bias = linear.bias.data.flatten().numpy()
             bias_length = self.write_npy(bias)
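
The new quant_bit == 16 branch writes raw fp16 weights with no alpha, while the 4/8-bit branch writes packed codes plus per-block (zero, scale) pairs. A rough, illustrative size estimate of the weight payload per linear layer, assuming the asymmetric path, float32 alphas, and ignoring the npy header and bias blobs:

def weight_payload_bytes(ic: int, oc: int, quant_bit: int, quant_block: int = 128) -> int:
    # Illustrative estimate only; mirrors the sizes implied by build_weight() above.
    if quant_bit == 16:
        return ic * oc * 2                                # fp16: 2 bytes per weight, no alpha
    block_size = ic if quant_block == 0 else quant_block
    blocks = oc * (ic // block_size)
    codes = ic * oc // 2 if quant_bit == 4 else ic * oc   # two 4-bit codes per byte
    alpha = blocks * 2 * 4                                # (zero, scale) as float32 per block
    return codes + alpha
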
@@ -1548,7 +1536,7 @@ def rebuild_linear(self, op, graph):

         quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit
         block_size = ic if self.quant_block == 0 else self.quant_block
-        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block, self.zeropoint)
+        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric)

         origin_input = op['inputIndexes']
         origin_output = op['outputIndexes']
@@ -1589,12 +1577,22 @@ def rebuild_linear(self, op, graph):
             },
             "defaultDimentionFormat": "NHWC"
         }
-        if self.zeropoint:
-            aMin = q_min
-            readType = oc * (ic // block_size)
+
+        if quant_bit == 16:
+            quanParameter = { "type": 3 }
         else:
-            aMin = 0
-            readType = 0
+            if self.symmetric:
+                aMin = 0
+                readType = 0
+            else:
+                aMin = q_min
+                readType = oc * (ic // block_size)
+
+            quanParameter = {
+                "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
+                "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
+                "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
+            }
         conv_op = {
             "name": conv_name,
             "inputIndexes": pre_convert_output,
@@ -1608,11 +1606,7 @@ def rebuild_linear(self, op, graph):
                 'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
                 'relu6': False, 'inputCount': ic, 'hasOutputShape': False
             },
-            "quanParameter": {
-                "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
-                "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
-                "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
-            },
+            "quanParameter": quanParameter,
             "external": external
         },
         "defaultDimentionFormat": "NHWC"
@@ -1938,10 +1932,10 @@ def __init__(self, lm_, final_layernorm_, config):
         self.final_layernorm = final_layernorm_
         self.lm = lm_
         self.hidden_size = config.hidden_size
-        self.all_logits = config.all_logits
+        self.ppl = config.ppl

     def forward(self, hidden_states):
-        if not self.all_logits:
+        if not self.ppl:
             # just need last logit for predict next token
             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         hidden_states = self.final_layernorm(hidden_states)
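
The renamed flag only changes what reaches the lm head: with ppl disabled the head sees just the last position, with it enabled every input token keeps its logits so perplexity can be computed. A tiny illustration with made-up shapes:

import torch

hidden_size = 4096                                  # illustrative
hidden_states = torch.randn(1, 7, hidden_size)      # 7 input tokens
# ppl=False: only the last token's hidden state is kept before the lm head
last = hidden_states.view(-1, hidden_size)[-1].view(1, 1, hidden_size)   # shape (1, 1, 4096)
# ppl=True: hidden_states passes through unchanged, so logits cover all 7 tokens
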
@@ -2233,10 +2227,11 @@ def init_from_args(self, args):
         self.tokenizer_path = args.tokenizer_path
         self.lora_path = args.lora_path
         self.skip_slim = args.skip_slim
-        self.all_logits = args.all_logits
+        self.ppl = args.ppl
+        self.awq = args.awq
         self.quant_bit = args.quant_bit
         self.quant_block = args.quant_block
-        self.zeropoint = args.zeropoint
+        self.symmetric = args.sym
         self.mnnconvert = args.mnnconvert
         if self.tokenizer_path is None:
             self.tokenizer_path = self.path
@@ -2251,10 +2246,10 @@ def init_from_args(self, args):
             os.makedirs(self.onnx_path)

     def load_pretrained(self, model_path: str):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True, use_fast=False)
         if 'Qwen2-VL' in model_path:
             from transformers import Qwen2VLForConditionalGeneration
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).half().eval()
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).float().eval()
         elif 'Llama-3.2' in model_path and 'Vision' in model_path:
             from transformers import MllamaForConditionalGeneration
             self.model = MllamaForConditionalGeneration.from_pretrained(model_path).float().eval()
@@ -2297,7 +2292,7 @@ def load_model(self, model_path):
         model_mapper = ModelMapper()

         self.model_type, self.model_map = model_mapper.get_map(self.config)
-        print(self.config, self.model_type, self.model_map, self.model)
+        # print(self.config, self.model_type, self.model_map, self.model)
         # load config info
         ModelMapper.do_map(self, self.config, self.model_map['config'])
         if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
@@ -2420,7 +2415,7 @@ def forward(self,
             hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
             presents[i] = kv
         logits = self.lm(hidden_states)
-        if not self.all_logits:
+        if not self.ppl:
             logits = logits.reshape(-1)
         if presents[0].shape == presents[-1].shape and None not in presents:
             presents = torch.stack(presents)
@@ -2478,7 +2473,6 @@ def id_to_str(self, token_id):

     def response(self, query):
         # self.imitate_quant()
-        self.awq_quant()
         prompt = self.build_prompt(query)
         input_ids = self.str_to_ids(prompt)
         if self.visual is not None:
@@ -2562,29 +2556,6 @@ def export_config(self, mnn_config = False):
             json.dump(config, f, ensure_ascii=False, indent=4)
         return config_json

-    def quant(self, weight, quant_bit, quant_block):
-        weight = weight.numpy()
-        oc, ic = weight.shape
-        if quant_block == 0:
-            block_size = ic
-        else:
-            block_size = quant_block
-        block_num = ic // block_size
-        weight = weight.reshape(oc, block_num, block_size)
-        max_val = np.max(weight, axis=-1, keepdims=True)
-        min_val = np.min(weight, axis=-1, keepdims=True)
-        offset = 1 << (quant_bit - 1)
-        clip_max = offset - 1
-        clip_min = -offset
-        scale = (max_val - min_val) / (clip_max - clip_min)
-        q_weight = np.round((weight - min_val) / scale) + clip_min
-        q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-        q_weight = q_weight.reshape(-1, 2)
-        if quant_bit == 4:
-            q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
-        alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        return q_weight, alpha, clip_min
-
     def imitate_quant(self):
         def quant_dequant(linear, quant_bit = self.quant_bit, quant_block = self.quant_block):
             weight = linear.weight.data
@@ -2677,13 +2648,13 @@ def export_onnx(self):
         return onnx_model

     def awq_quant(self):
-        print('### AWQ quant')
         self.awq_quantizer = AwqQuantizer(self)
         self.awq_quantizer.quantize()
         self.is_awq_quantized = True

     def export(self, export_type):
-        # self.awq_quant()
+        if self.awq:
+            self.awq_quant()
         export_mnn = export_type == 'mnn'
         # export tokenizer
         self.export_tokenizer()
@@ -2751,6 +2722,14 @@ def write_header(fp, type, speicals, prefix = []):
         prefix_list = []
         if hasattr(self.tokenizer, 'get_prefix_tokens'):
             prefix_list = self.tokenizer.get_prefix_tokens()
+        if len(prefix_list) == 0:
+            test_txt = 'A'
+            ids = self.tokenizer.encode(test_txt)
+            get_txt = self.tokenizer.decode(ids[-1])
+            if len(ids) > 1 and get_txt == test_txt:
+                prefix_list += ids[:-1]
+            print(prefix_list)
+
         if self.sp_model is not None:
             # senetencepiece
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
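
The new fallback guesses prefix tokens empirically: encode a single character, check that the last id decodes back to that character, and treat any leading ids as the tokenizer's prefix. A standalone sketch of the same heuristic, with an illustrative model name in the usage comment:

from transformers import AutoTokenizer

def guess_prefix_tokens(tokenizer) -> list:
    # Mirrors the fallback above: ids emitted before a lone 'A' are treated as prefix tokens.
    test_txt = 'A'
    ids = tokenizer.encode(test_txt)
    if len(ids) > 1 and tokenizer.decode(ids[-1]) == test_txt:
        return ids[:-1]
    return []

# e.g. a Llama-style tokenizer may yield its BOS id here (model name is illustrative)
# tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_fast=False)
# print(guess_prefix_tokens(tokenizer))
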
@@ -3041,13 +3020,13 @@ def main():
     parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
     parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
-    parser.add_argument('--all_logits', action='store_true', help='Whether or not to get all logits of input tokens.')
-    parser.add_argument('--quant_type', type=str, default='RTN', help='Quant type in [RTN, AWQ], default is `RTN`')
     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
-    parser.add_argument('--zeropoint', action='store_false', help='Whether or not to zeropoint when quant weight, default is True.')
     parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 0 mean channle-wise.')
     parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
     parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
+    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
+    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
+    parser.add_argument('--sym', action='store_true', help='Whether or not to use symmetric quant (without zeropoint), default is False.')

     args = parser.parse_args()

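
Taken together, the new flags change how an export is driven from the command line. A hypothetical invocation, assuming the script's existing --path argument for the checkpoint directory (the model path below is a placeholder):

# 4-bit, block-128, symmetric weights with AWQ scale search, exported to MNN
python llmexport.py --path ./Qwen2-1.5B-Instruct --export mnn \
    --quant_bit 4 --quant_block 128 --sym --awq

# keep logits for every input token (useful for perplexity evaluation)
python llmexport.py --path ./Qwen2-1.5B-Instruct --export onnx --ppl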