
Commit 3170a03

[feat] support mobilellm.
1 parent 35c3804 commit 3170a03

1 file changed (+80 −98 lines)

llmexport/llmexport.py

Lines changed: 80 additions & 98 deletions
@@ -94,6 +94,7 @@ def regist_llama(self):
         self.regist('llama', llama_map)
         self.regist('qwen2', llama_map)
         self.regist('internlm', llama_map)
+        self.regist('mobilellm', llama_map)
         # baichuan
         baichuan_map = copy.deepcopy(self.default_map)
         baichuan_map[self.attention_key] = {
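MobileLLM uses a LLaMA-style decoder, so the exporter only needs to register the new model type against the existing `llama_map`; no new weight mapping is required. The sketch below illustrates that registry pattern with hypothetical internals (the real `ModelMapper` carries full config/decoder/attention key maps):

    import copy

    class ModelMapper:
        """Sketch of a model_type -> key-mapping registry (internals are illustrative only)."""
        def __init__(self):
            self.maps = {}
            self.default_map = {'config': {}, 'decoder': {}, 'attention': {}}

        def regist(self, model_type, model_map):
            self.maps[model_type] = model_map

        def get_map(self, config):
            model_type = getattr(config, 'model_type', 'llama')
            return model_type, self.maps.get(model_type, self.default_map)

    mapper = ModelMapper()
    llama_map = copy.deepcopy(mapper.default_map)
    for arch in ('llama', 'qwen2', 'internlm', 'mobilellm'):
        mapper.regist(arch, llama_map)  # mobilellm simply reuses the llama mapping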
@@ -365,9 +366,7 @@ def __init__(
         self.tokenizer = model.tokenizer
         self.w_bit = model.quant_bit
         self.group_size = model.quant_block
-        self.zeropoint = False
-        # self.calib_data = model.calib_data
-        # self.split = model.split
+        self.zeropoint = not model.symmetric
         self.calib_data = 'ag_news'
         self.split = 'test'
         self.duo_scaling = True
@@ -390,42 +389,20 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
         w = w.reshape(-1, self.group_size)
         assert w.dim() == 2
         assert torch.isnan(w).sum() == 0
-
         # zero point quantization
         if self.zeropoint:
-            '''
-            max_val = w.amax(dim=1, keepdim=True)
-            min_val = w.amin(dim=1, keepdim=True)
-            max_int = 2**self.w_bit - 1
-            min_int = 0
-            scales = (max_val - min_val).clamp(min=1e-5) / max_int
-            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-            w = (
-                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
-            ) * scales
-            zeros = zeros.view(org_w_shape[0], -1)
-            '''
-            #'''
             max_val = w.amax(dim=1, keepdim=True)
             min_val = w.amin(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
             clip_min = -offset
             scales = (max_val - min_val) / (clip_max - clip_min)
+            zeros = - torch.round(min_val / scales) + clip_min
+            qw = torch.round(w / scales) + zeros
+            qw = torch.clamp(qw, clip_min, clip_max)
+            w = (qw - zeros) * scales
             zeros = min_val.view(org_w_shape[0], -1)
-            qw = torch.clamp(torch.round((w - min_val) / scales) + clip_min, clip_min, clip_max)
-            w = (qw - clip_min) * scales + min_val
-            #'''
         else:
-            '''
-            max_val = w.abs().amax(dim=1, keepdim=True)
-            max_val = max_val.clamp(min=1e-5)
-            max_int = 2 ** (self.w_bit - 1) - 1
-            min_int = -(2 ** (self.w_bit - 1))
-            scales = max_val / max_int
-            zeros = None
-            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
-            '''
             abs_max = w.abs().amax(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
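The rewritten zero-point branch above maps each group of `group_size` weights to signed `w_bit` integers with a per-group zero point and immediately dequantizes them (AWQ only needs the fake-quantized weights). A standalone round-trip sketch of the same arithmetic, with a hypothetical function name and defaults:

    import torch

    def fake_quant_asym(w: torch.Tensor, w_bit: int = 4, group_size: int = 128) -> torch.Tensor:
        """Quantize-dequantize with a per-group zero point (mirrors the zeropoint branch)."""
        orig_shape = w.shape
        w = w.reshape(-1, group_size)
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        offset = 1 << (w_bit - 1)
        clip_max, clip_min = offset - 1, -offset
        scales = (max_val - min_val) / (clip_max - clip_min)
        zeros = -torch.round(min_val / scales) + clip_min
        qw = torch.clamp(torch.round(w / scales) + zeros, clip_min, clip_max)
        return ((qw - zeros) * scales).reshape(orig_shape)

    w = torch.randn(64, 256)
    err = (fake_quant_asym(w) - w).abs().max()  # reconstruction error is bounded by roughly scale / 2 per group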
@@ -464,7 +441,7 @@ def quantize(self):
             ].to(common_device)

             self.inps = self.inps.to(common_device)
-            print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')
+            # print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')

             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = AwqQuantizer.get_named_linears(self.modules[i])
@@ -522,7 +499,7 @@ def quantize(self):
                 self._search_best_scale(self.modules[i], **layer)
                 for layer in module_config
             ]
-            print(scales_list); exit(0)
+            # print(scales_list); exit(0)
             AwqQuantizer.apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             # [STEP 3]: Compute and apply clipping list
             if self.apply_clip:
@@ -1316,7 +1293,7 @@ def __init__(self, onnx_path, weight_ops, config):
         self.quant_block = config.quant_block
         self.quant_bit = config.quant_bit
         self.lm_quant_bit = config.lm_quant_bit
-        self.zeropoint = config.zeropoint
+        self.symmetric = config.symmetric
         self.mnn_weight_offset = 0
         self.onnx_model_path = onnx_path
         self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
@@ -1430,7 +1407,7 @@ def rebuild(self, json_path):
             json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
         return self.mnn_weight_path

-    def quant(self, weight, quant_bit, quant_block, zeropoint):
+    def quant(self, weight, quant_bit, quant_block, symmetric):
         weight = weight.numpy()
         oc, ic = weight.shape
         if quant_block == 0:
@@ -1441,28 +1418,36 @@ def quant(self, weight, quant_bit, quant_block, zeropoint):
         weight = weight.reshape(oc, block_num, block_size)
         offset = 1 << (quant_bit - 1)
         clip_max = offset - 1
-        if zeropoint:
-            clip_min = -offset
-            max_val = np.max(weight, axis=-1, keepdims=True)
-            min_val = np.min(weight, axis=-1, keepdims=True)
-            scale = (max_val - min_val) / (clip_max - clip_min)
-
-            # q_weight = np.round((weight - min_val) / scale) + clip_min
-            q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
-            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-            alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        else:
+        if symmetric:
             clip_min = -clip_max
             abs_max = np.max(np.abs(weight), axis=-1, keepdims=True)
             scale = abs_max / clip_max
             q_weight = np.round(weight / scale)
             q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
             alpha = scale.flatten()
+        else:
+            clip_min = -offset
+            max_val = np.max(weight, axis=-1, keepdims=True)
+            min_val = np.min(weight, axis=-1, keepdims=True)
+            scale = (max_val - min_val) / (clip_max - clip_min)
+
+            import MNN
+            if MNN.version() <= '2.9.6':
+                q_weight = np.round((weight - min_val) / scale) + clip_min
+                zeros = min_val
+                aMin = clip_min
+            else:
+                q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
+                zeros = (np.round(min_val / scale) - clip_min) * scale
+                aMin = 1
+            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
+            alpha = np.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
         q_weight = q_weight.reshape(-1, 2)
         if quant_bit == 4:
             q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]

-        return q_weight, alpha, clip_min
+
+        return q_weight, alpha, aMin

     def write_npy(self, data):
         return self.mnn_weight.write(data.tobytes())
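In the asymmetric (`else`) branch, `alpha` now stores a (zeros, scale) pair per block, and the encoding depends on the installed MNN runtime: MNN <= 2.9.6 keeps the raw per-block minimum and returns the integer clip minimum as `aMin`, while newer runtimes expect a rescaled zero point and `aMin = 1`. Packing is unchanged in both branches: two 4-bit codes per byte, high nibble first. A per-block sketch with a hypothetical helper (the version switch is passed in explicitly instead of calling `MNN.version()`):

    import numpy as np

    def quant_block_asym(block: np.ndarray, quant_bit: int = 4, legacy_alpha: bool = True):
        """Asymmetric block quant; returns (uint8 codes, [zeros, scale], aMin)."""
        offset = 1 << (quant_bit - 1)
        clip_max, clip_min = offset - 1, -offset
        max_val, min_val = block.max(), block.min()
        scale = (max_val - min_val) / (clip_max - clip_min)
        if legacy_alpha:  # MNN <= 2.9.6 layout: store raw min, aMin = clip_min
            q = np.round((block - min_val) / scale) + clip_min
            zeros, a_min = min_val, clip_min
        else:             # newer layout: store rescaled zero point, aMin = 1
            q = np.round(block / scale) - np.round(min_val / scale) + clip_min
            zeros = (np.round(min_val / scale) - clip_min) * scale
            a_min = 1
        q = (np.clip(q, clip_min, clip_max) + offset).astype(np.uint8)
        return q, np.array([zeros, scale]), a_min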
@@ -1483,12 +1468,18 @@ def write_header(self, ic, oc, quant_bit):
         header_length = dim_num + dim_length + map_length
         return header_length, shape_dtype == np.int32

-    def build_weight(self, linear, quant_bit, quant_block, zeropoint):
+    def build_weight(self, linear, quant_bit, quant_block, symmetric):
         ic, oc = linear.in_features, linear.out_features
-        q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, zeropoint)
-        header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
-        weight_len = self.write_npy(q_weight) + header_len
-        alpha_len = self.write_npy(alpha)
+        if quant_bit == 16:
+            half_weight = linear.weight.data.half().flatten().numpy()
+            weight_len = self.write_npy(half_weight)
+            alpha_len, q_min, shape_int32 = 0, 0, False
+        else:
+            assert(quant_bit in (4, 8))
+            q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
+            header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
+            weight_len = self.write_npy(q_weight) + header_len
+            alpha_len = self.write_npy(alpha)
         if linear.bias is not None:
             bias = linear.bias.data.flatten().numpy()
             bias_length = self.write_npy(bias)
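`build_weight` now short-circuits for `quant_bit == 16`: the weights are written as raw float16 with no quant header and no alpha block, and only 4-/8-bit weights go through `quant()`. A minimal sketch of just that fp16 path (a hypothetical writer, not the MNN external-weight layout):

    import io
    import torch

    def write_fp16_weight(fp, weight: torch.Tensor) -> int:
        """quant_bit == 16: dump raw float16 bytes; no header, no alpha, shape_int32 = False."""
        half = weight.detach().half().flatten().numpy()
        return fp.write(half.tobytes())

    buf = io.BytesIO()
    nbytes = write_fp16_weight(buf, torch.randn(8, 8))  # 8 * 8 * 2 = 128 bytes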
@@ -1548,7 +1539,7 @@ def rebuild_linear(self, op, graph):

         quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit
         block_size = ic if self.quant_block == 0 else self.quant_block
-        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block, self.zeropoint)
+        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric)

         origin_input = op['inputIndexes']
         origin_output = op['outputIndexes']
@@ -1589,12 +1580,22 @@ def rebuild_linear(self, op, graph):
             },
             "defaultDimentionFormat": "NHWC"
         }
-        if self.zeropoint:
-            aMin = q_min
-            readType = oc * (ic // block_size)
+
+        if quant_bit == 16:
+            quanParameter = { "type": 3 }
         else:
-            aMin = 0
-            readType = 0
+            if self.symmetric:
+                aMin = 0
+                readType = 0
+            else:
+                aMin = q_min
+                readType = oc * (ic // block_size)
+
+            quanParameter = {
+                "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
+                "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
+                "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
+            }
         conv_op = {
             "name": conv_name,
             "inputIndexes": pre_convert_output,
@@ -1608,11 +1609,7 @@ def rebuild_linear(self, op, graph):
                     'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
                     'relu6': False, 'inputCount': ic, 'hasOutputShape': False
                 },
-                "quanParameter": {
-                    "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
-                    "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
-                    "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
-                },
+                "quanParameter": quanParameter,
                 "external": external
             },
             "defaultDimentionFormat": "NHWC"
@@ -1938,10 +1935,10 @@ def __init__(self, lm_, final_layernorm_, config):
         self.final_layernorm = final_layernorm_
         self.lm = lm_
         self.hidden_size = config.hidden_size
-        self.all_logits = config.all_logits
+        self.ppl = config.ppl

     def forward(self, hidden_states):
-        if not self.all_logits:
+        if not self.ppl:
             # just need last logit for predict next token
             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         hidden_states = self.final_layernorm(hidden_states)
@@ -2233,10 +2230,11 @@ def init_from_args(self, args):
         self.tokenizer_path = args.tokenizer_path
         self.lora_path = args.lora_path
         self.skip_slim = args.skip_slim
-        self.all_logits = args.all_logits
+        self.ppl = args.ppl
+        self.awq = args.awq
         self.quant_bit = args.quant_bit
         self.quant_block = args.quant_block
-        self.zeropoint = args.zeropoint
+        self.symmetric = args.sym
         self.mnnconvert = args.mnnconvert
         if self.tokenizer_path is None:
             self.tokenizer_path = self.path
@@ -2251,10 +2249,10 @@ def init_from_args(self, args):
             os.makedirs(self.onnx_path)

     def load_pretrained(self, model_path: str):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True, use_fast=False)
         if 'Qwen2-VL' in model_path:
             from transformers import Qwen2VLForConditionalGeneration
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).half().eval()
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).float().eval()
         elif 'Llama-3.2' in model_path and 'Vision' in model_path:
             from transformers import MllamaForConditionalGeneration
             self.model = MllamaForConditionalGeneration.from_pretrained(model_path).float().eval()
@@ -2297,7 +2295,7 @@ def load_model(self, model_path):
         model_mapper = ModelMapper()

         self.model_type, self.model_map = model_mapper.get_map(self.config)
-        print(self.config, self.model_type, self.model_map, self.model)
+        # print(self.config, self.model_type, self.model_map, self.model)
         # load config info
         ModelMapper.do_map(self, self.config, self.model_map['config'])
         if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
@@ -2420,7 +2418,7 @@ def forward(self,
             hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
             presents[i] = kv
         logits = self.lm(hidden_states)
-        if not self.all_logits:
+        if not self.ppl:
             logits = logits.reshape(-1)
         if presents[0].shape == presents[-1].shape and None not in presents:
             presents = torch.stack(presents)
@@ -2478,7 +2476,6 @@ def id_to_str(self, token_id):

     def response(self, query):
         # self.imitate_quant()
-        self.awq_quant()
         prompt = self.build_prompt(query)
         input_ids = self.str_to_ids(prompt)
         if self.visual is not None:
@@ -2562,29 +2559,6 @@ def export_config(self, mnn_config = False):
             json.dump(config, f, ensure_ascii=False, indent=4)
         return config_json

-    def quant(self, weight, quant_bit, quant_block):
-        weight = weight.numpy()
-        oc, ic = weight.shape
-        if quant_block == 0:
-            block_size = ic
-        else:
-            block_size = quant_block
-        block_num = ic // block_size
-        weight = weight.reshape(oc, block_num, block_size)
-        max_val = np.max(weight, axis=-1, keepdims=True)
-        min_val = np.min(weight, axis=-1, keepdims=True)
-        offset = 1 << (quant_bit - 1)
-        clip_max = offset - 1
-        clip_min = -offset
-        scale = (max_val - min_val) / (clip_max - clip_min)
-        q_weight = np.round((weight - min_val) / scale) + clip_min
-        q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-        q_weight = q_weight.reshape(-1, 2)
-        if quant_bit == 4:
-            q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
-        alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        return q_weight, alpha, clip_min
-
     def imitate_quant(self):
         def quant_dequant(linear, quant_bit = self.quant_bit, quant_block = self.quant_block):
             weight = linear.weight.data
@@ -2677,13 +2651,13 @@ def export_onnx(self):
         return onnx_model

     def awq_quant(self):
-        print('### AWQ quant')
         self.awq_quantizer = AwqQuantizer(self)
         self.awq_quantizer.quantize()
         self.is_awq_quantized = True

     def export(self, export_type):
-        # self.awq_quant()
+        if self.awq:
+            self.awq_quant()
         export_mnn = export_type == 'mnn'
         # export tokenizer
         self.export_tokenizer()
@@ -2751,6 +2725,14 @@ def write_header(fp, type, speicals, prefix = []):
         prefix_list = []
         if hasattr(self.tokenizer, 'get_prefix_tokens'):
             prefix_list = self.tokenizer.get_prefix_tokens()
+        if len(prefix_list) == 0:
+            test_txt = 'A'
+            ids = self.tokenizer.encode(test_txt)
+            get_txt = self.tokenizer.decode(ids[-1])
+            if len(ids) > 1 and get_txt == test_txt:
+                prefix_list += ids[:-1]
+        print(prefix_list)
+
         if self.sp_model is not None:
             # senetencepiece
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
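When the tokenizer does not expose `get_prefix_tokens()`, the prefix (e.g. a BOS token) is now inferred by encoding a single known character and checking that the last id alone decodes back to it; any leading ids are taken as the prompt prefix. A standalone sketch of the heuristic for a Hugging Face tokenizer (the model id in the commented usage is only an example):

    def detect_prefix_tokens(tokenizer, probe: str = 'A'):
        """Infer prefix/BOS ids: encode a probe char, keep the ids before the one that decodes back to it."""
        if hasattr(tokenizer, 'get_prefix_tokens'):
            prefix = tokenizer.get_prefix_tokens()
            if prefix:
                return prefix
        ids = tokenizer.encode(probe)
        if len(ids) > 1 and tokenizer.decode(ids[-1]) == probe:
            return ids[:-1]
        return []

    # from transformers import AutoTokenizer
    # tok = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_fast=False)
    # detect_prefix_tokens(tok)  # -> [bos_id] for LLaMA-style tokenizers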
@@ -3041,13 +3023,13 @@ def main():
     parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
     parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
-    parser.add_argument('--all_logits', action='store_true', help='Whether or not to get all logits of input tokens.')
-    parser.add_argument('--quant_type', type=str, default='RTN', help='Quant type in [RTN, AWQ], default is `RTN`')
     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
-    parser.add_argument('--zeropoint', action='store_false', help='Whether or not to zeropoint when quant weight, default is True.')
     parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 0 mean channle-wise.')
     parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
     parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
+    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
+    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
+    parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.')

     args = parser.parse_args()

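Taken together, the CLI changes drop `--all_logits`, `--zeropoint`, and `--quant_type` in favour of `--ppl`, `--sym`, and `--awq`; zero-point (asymmetric) quantization is now the default and `--sym` switches to symmetric. A hypothetical invocation exporting a MobileLLM checkpoint to MNN with 4-bit asymmetric weights and AWQ scaling; the `--path` argument and the model directory are assumptions, so check `python llmexport.py -h` for the actual required arguments:

    python llmexport.py \
        --path ./MobileLLM-1B \
        --export mnn \
        --quant_bit 4 --quant_block 128 \
        --awq  # add --sym to use symmetric quantization instead of zero-point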