@@ -94,6 +94,7 @@ def regist_llama(self):
         self.regist('llama', llama_map)
         self.regist('qwen2', llama_map)
         self.regist('internlm', llama_map)
+        self.regist('mobilellm', llama_map)
         # baichuan
         baichuan_map = copy.deepcopy(self.default_map)
         baichuan_map[self.attention_key] = {
@@ -365,9 +366,7 @@ def __init__(
         self.tokenizer = model.tokenizer
         self.w_bit = model.quant_bit
         self.group_size = model.quant_block
-        self.zeropoint = False
-        # self.calib_data = model.calib_data
-        # self.split = model.split
+        self.zeropoint = not model.symmetric
         self.calib_data = 'ag_news'
         self.split = 'test'
         self.duo_scaling = True
@@ -390,42 +389,20 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
         w = w.reshape(-1, self.group_size)
         assert w.dim() == 2
         assert torch.isnan(w).sum() == 0
-
         # zero point quantization
         if self.zeropoint:
-            '''
-            max_val = w.amax(dim=1, keepdim=True)
-            min_val = w.amin(dim=1, keepdim=True)
-            max_int = 2**self.w_bit - 1
-            min_int = 0
-            scales = (max_val - min_val).clamp(min=1e-5) / max_int
-            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-            w = (
-                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
-            ) * scales
-            zeros = zeros.view(org_w_shape[0], -1)
-            '''
-            #'''
             max_val = w.amax(dim=1, keepdim=True)
             min_val = w.amin(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
             clip_min = -offset
             scales = (max_val - min_val) / (clip_max - clip_min)
+            zeros = -torch.round(min_val / scales) + clip_min
+            qw = torch.round(w / scales) + zeros
+            qw = torch.clamp(qw, clip_min, clip_max)
+            w = (qw - zeros) * scales
             zeros = min_val.view(org_w_shape[0], -1)
-            qw = torch.clamp(torch.round((w - min_val) / scales) + clip_min, clip_min, clip_max)
-            w = (qw - clip_min) * scales + min_val
-            #'''
         else:
-            '''
-            max_val = w.abs().amax(dim=1, keepdim=True)
-            max_val = max_val.clamp(min=1e-5)
-            max_int = 2 ** (self.w_bit - 1) - 1
-            min_int = -(2 ** (self.w_bit - 1))
-            scales = max_val / max_int
-            zeros = None
-            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
-            '''
             abs_max = w.abs().amax(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
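For reference, the rewritten zero-point branch above maps each group onto the signed range [-2^(b-1), 2^(b-1)-1] and folds `min_val` into an integer zero offset, so dequantization is a single `(q - zero) * scale`. A minimal standalone round-trip of the same formulas (a sketch on a random group, not part of the patch):

```python
import torch

# Sketch: asymmetric (zero-point) quantize/dequantize of one weight group,
# mirroring the formulas in the patched zero-point branch above.
w_bit = 4
w = torch.randn(1, 64)                             # one quantization group

max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
offset = 1 << (w_bit - 1)                          # 8
clip_max, clip_min = offset - 1, -offset           # 7, -8
scales = (max_val - min_val) / (clip_max - clip_min)

zeros = -torch.round(min_val / scales) + clip_min  # integer zero point
qw = torch.clamp(torch.round(w / scales) + zeros, clip_min, clip_max)
w_dq = (qw - zeros) * scales                       # dequantized weights

print((w - w_dq).abs().max().item())               # error is about half a quant step
```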
@@ -464,7 +441,7 @@ def quantize(self):
                 ].to(common_device)

             self.inps = self.inps.to(common_device)
-            print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')
+            # print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')

             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = AwqQuantizer.get_named_linears(self.modules[i])
@@ -522,7 +499,7 @@ def quantize(self):
                 self._search_best_scale(self.modules[i], **layer)
                 for layer in module_config
             ]
-            print(scales_list); exit(0)
+            # print(scales_list); exit(0)
             AwqQuantizer.apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             # [STEP 3]: Compute and apply clipping list
             if self.apply_clip:
@@ -1316,7 +1293,7 @@ def __init__(self, onnx_path, weight_ops, config):
         self.quant_block = config.quant_block
         self.quant_bit = config.quant_bit
         self.lm_quant_bit = config.lm_quant_bit
-        self.zeropoint = config.zeropoint
+        self.symmetric = config.symmetric
         self.mnn_weight_offset = 0
         self.onnx_model_path = onnx_path
         self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
@@ -1430,7 +1407,7 @@ def rebuild(self, json_path):
             json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
         return self.mnn_weight_path

-    def quant(self, weight, quant_bit, quant_block, zeropoint):
+    def quant(self, weight, quant_bit, quant_block, symmetric):
         weight = weight.numpy()
         oc, ic = weight.shape
         if quant_block == 0:
@@ -1441,28 +1418,36 @@ def quant(self, weight, quant_bit, quant_block, zeropoint):
         weight = weight.reshape(oc, block_num, block_size)
         offset = 1 << (quant_bit - 1)
         clip_max = offset - 1
-        if zeropoint:
-            clip_min = -offset
-            max_val = np.max(weight, axis=-1, keepdims=True)
-            min_val = np.min(weight, axis=-1, keepdims=True)
-            scale = (max_val - min_val) / (clip_max - clip_min)
-
-            # q_weight = np.round((weight - min_val) / scale) + clip_min
-            q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
-            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-            alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        else:
+        if symmetric:
             clip_min = -clip_max
             abs_max = np.max(np.abs(weight), axis=-1, keepdims=True)
             scale = abs_max / clip_max
             q_weight = np.round(weight / scale)
             q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
             alpha = scale.flatten()
+        else:
+            clip_min = -offset
+            max_val = np.max(weight, axis=-1, keepdims=True)
+            min_val = np.min(weight, axis=-1, keepdims=True)
+            scale = (max_val - min_val) / (clip_max - clip_min)
+
+            import MNN
+            if MNN.version() <= '2.9.6':
+                q_weight = np.round((weight - min_val) / scale) + clip_min
+                zeros = min_val
+                aMin = clip_min
+            else:
+                q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
+                zeros = (np.round(min_val / scale) - clip_min) * scale
+                aMin = 1
+            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
+            alpha = np.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
         q_weight = q_weight.reshape(-1, 2)
         if quant_bit == 4:
             q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]

-        return q_weight, alpha, clip_min
+
+        return q_weight, alpha, aMin

     def write_npy(self, data):
         return self.mnn_weight.write(data.tobytes())
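Note that for 4-bit weights the quantized codes are packed two per byte, high nibble first (`q[:, 0] * 16 + q[:, 1]`). A tiny standalone illustration with hypothetical codes already offset into [0, 15]:

```python
import numpy as np

# Hypothetical 4-bit codes (already shifted by `offset` into [0, 15]).
q = np.array([1, 15, 0, 7], dtype=np.uint8)
pairs = q.reshape(-1, 2)
packed = pairs[:, 0] * 16 + pairs[:, 1]    # high nibble first: [31, 7]
print(packed)

# Unpacking recovers the original codes: [1, 15, 0, 7].
unpacked = np.stack([packed // 16, packed % 16], axis=-1).reshape(-1)
print(unpacked)
```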
@@ -1483,12 +1468,18 @@ def write_header(self, ic, oc, quant_bit):
         header_length = dim_num + dim_length + map_length
         return header_length, shape_dtype == np.int32

-    def build_weight(self, linear, quant_bit, quant_block, zeropoint):
+    def build_weight(self, linear, quant_bit, quant_block, symmetric):
         ic, oc = linear.in_features, linear.out_features
-        q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, zeropoint)
-        header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
-        weight_len = self.write_npy(q_weight) + header_len
-        alpha_len = self.write_npy(alpha)
+        if quant_bit == 16:
+            half_weight = linear.weight.data.half().flatten().numpy()
+            weight_len = self.write_npy(half_weight)
+            alpha_len, q_min, shape_int32 = 0, 0, False
+        else:
+            assert(quant_bit in (4, 8))
+            q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
+            header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
+            weight_len = self.write_npy(q_weight) + header_len
+            alpha_len = self.write_npy(alpha)
         if linear.bias is not None:
             bias = linear.bias.data.flatten().numpy()
             bias_length = self.write_npy(bias)
@@ -1548,7 +1539,7 @@ def rebuild_linear(self, op, graph):
15481539
15491540 quant_bit = self .lm_quant_bit if 'lm_head' in name else self .quant_bit
15501541 block_size = ic if self .quant_block == 0 else self .quant_block
1551- external , q_min , shape_int32 = self .build_weight (linear , quant_bit , self .quant_block , self .zeropoint )
1542+ external , q_min , shape_int32 = self .build_weight (linear , quant_bit , self .quant_block , self .symmetric )
15521543
15531544 origin_input = op ['inputIndexes' ]
15541545 origin_output = op ['outputIndexes' ]
@@ -1589,12 +1580,22 @@ def rebuild_linear(self, op, graph):
             },
             "defaultDimentionFormat": "NHWC"
         }
-        if self.zeropoint:
-            aMin = q_min
-            readType = oc * (ic // block_size)
+
+        if quant_bit == 16:
+            quanParameter = { "type": 3 }
         else:
-            aMin = 0
-            readType = 0
+            if self.symmetric:
+                aMin = 0
+                readType = 0
+            else:
+                aMin = q_min
+                readType = oc * (ic // block_size)
+
+            quanParameter = {
+                "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
+                "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
+                "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
+            }
         conv_op = {
             "name": conv_name,
             "inputIndexes": pre_convert_output,
@@ -1608,11 +1609,7 @@ def rebuild_linear(self, op, graph):
                     'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
                     'relu6': False, 'inputCount': ic, 'hasOutputShape': False
                 },
-                "quanParameter": {
-                    "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
-                    "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
-                    "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
-                },
+                "quanParameter": quanParameter,
                 "external": external
             },
             "defaultDimentionFormat": "NHWC"
@@ -1938,10 +1935,10 @@ def __init__(self, lm_, final_layernorm_, config):
         self.final_layernorm = final_layernorm_
         self.lm = lm_
         self.hidden_size = config.hidden_size
-        self.all_logits = config.all_logits
+        self.ppl = config.ppl

     def forward(self, hidden_states):
-        if not self.all_logits:
+        if not self.ppl:
             # just need last logit for predict next token
             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         hidden_states = self.final_layernorm(hidden_states)
@@ -2233,10 +2230,11 @@ def init_from_args(self, args):
         self.tokenizer_path = args.tokenizer_path
         self.lora_path = args.lora_path
         self.skip_slim = args.skip_slim
-        self.all_logits = args.all_logits
+        self.ppl = args.ppl
+        self.awq = args.awq
         self.quant_bit = args.quant_bit
         self.quant_block = args.quant_block
-        self.zeropoint = args.zeropoint
+        self.symmetric = args.sym
         self.mnnconvert = args.mnnconvert
         if self.tokenizer_path is None:
             self.tokenizer_path = self.path
@@ -2251,10 +2249,10 @@ def init_from_args(self, args):
             os.makedirs(self.onnx_path)

     def load_pretrained(self, model_path: str):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True, use_fast=False)
         if 'Qwen2-VL' in model_path:
             from transformers import Qwen2VLForConditionalGeneration
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).half().eval()
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).float().eval()
         elif 'Llama-3.2' in model_path and 'Vision' in model_path:
             from transformers import MllamaForConditionalGeneration
             self.model = MllamaForConditionalGeneration.from_pretrained(model_path).float().eval()
@@ -2297,7 +2295,7 @@ def load_model(self, model_path):
         model_mapper = ModelMapper()

         self.model_type, self.model_map = model_mapper.get_map(self.config)
-        print(self.config, self.model_type, self.model_map, self.model)
+        # print(self.config, self.model_type, self.model_map, self.model)
         # load config info
         ModelMapper.do_map(self, self.config, self.model_map['config'])
         if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
@@ -2420,7 +2418,7 @@ def forward(self,
             hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
             presents[i] = kv
         logits = self.lm(hidden_states)
-        if not self.all_logits:
+        if not self.ppl:
             logits = logits.reshape(-1)
         if presents[0].shape == presents[-1].shape and None not in presents:
             presents = torch.stack(presents)
@@ -2478,7 +2476,6 @@ def id_to_str(self, token_id):

     def response(self, query):
         # self.imitate_quant()
-        self.awq_quant()
         prompt = self.build_prompt(query)
         input_ids = self.str_to_ids(prompt)
         if self.visual is not None:
@@ -2562,29 +2559,6 @@ def export_config(self, mnn_config = False):
             json.dump(config, f, ensure_ascii=False, indent=4)
         return config_json

-    def quant(self, weight, quant_bit, quant_block):
-        weight = weight.numpy()
-        oc, ic = weight.shape
-        if quant_block == 0:
-            block_size = ic
-        else:
-            block_size = quant_block
-        block_num = ic // block_size
-        weight = weight.reshape(oc, block_num, block_size)
-        max_val = np.max(weight, axis=-1, keepdims=True)
-        min_val = np.min(weight, axis=-1, keepdims=True)
-        offset = 1 << (quant_bit - 1)
-        clip_max = offset - 1
-        clip_min = -offset
-        scale = (max_val - min_val) / (clip_max - clip_min)
-        q_weight = np.round((weight - min_val) / scale) + clip_min
-        q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-        q_weight = q_weight.reshape(-1, 2)
-        if quant_bit == 4:
-            q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
-        alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        return q_weight, alpha, clip_min
-
     def imitate_quant(self):
         def quant_dequant(linear, quant_bit=self.quant_bit, quant_block=self.quant_block):
             weight = linear.weight.data
@@ -2677,13 +2651,13 @@ def export_onnx(self):
         return onnx_model

     def awq_quant(self):
-        print('### AWQ quant')
         self.awq_quantizer = AwqQuantizer(self)
         self.awq_quantizer.quantize()
         self.is_awq_quantized = True

     def export(self, export_type):
-        # self.awq_quant()
+        if self.awq:
+            self.awq_quant()
         export_mnn = export_type == 'mnn'
         # export tokenizer
         self.export_tokenizer()
@@ -2751,6 +2725,14 @@ def write_header(fp, type, speicals, prefix = []):
         prefix_list = []
         if hasattr(self.tokenizer, 'get_prefix_tokens'):
             prefix_list = self.tokenizer.get_prefix_tokens()
+        if len(prefix_list) == 0:
+            test_txt = 'A'
+            ids = self.tokenizer.encode(test_txt)
+            get_txt = self.tokenizer.decode(ids[-1])
+            if len(ids) > 1 and get_txt == test_txt:
+                prefix_list += ids[:-1]
+        print(prefix_list)
+
         if self.sp_model is not None:
             # senetencepiece
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
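The fallback added above infers prefix tokens for tokenizers that lack `get_prefix_tokens`: encode a single character, and if extra leading ids appear while the last id alone decodes back to that character, treat the leading ids as the prompt prefix (typically BOS). A standalone sketch with a stub tokenizer and hypothetical ids (1 = BOS, 320 = 'A'):

```python
# Stub tokenizer with hypothetical ids, only to illustrate the heuristic above.
class StubTokenizer:
    def encode(self, text):
        return [1, 320]                    # hypothetical: [BOS, id('A')]

    def decode(self, ids):
        return 'A' if ids in (320, [320]) else '<s>'

tokenizer = StubTokenizer()
prefix_list = []
test_txt = 'A'
ids = tokenizer.encode(test_txt)
get_txt = tokenizer.decode(ids[-1])
if len(ids) > 1 and get_txt == test_txt:
    prefix_list += ids[:-1]
print(prefix_list)                         # [1] -> BOS becomes the exported prefix
```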
@@ -3041,13 +3023,13 @@ def main():
     parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
     parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
-    parser.add_argument('--all_logits', action='store_true', help='Whether or not to get all logits of input tokens.')
-    parser.add_argument('--quant_type', type=str, default='RTN', help='Quant type in [RTN, AWQ], default is `RTN`')
     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
-    parser.add_argument('--zeropoint', action='store_false', help='Whether or not to zeropoint when quant weight, default is True.')
     parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 128, 0 means channel-wise.')
     parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
     parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, use pymnn.')
+    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
+    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
+    parser.add_argument('--sym', action='store_true', help='Whether or not to use symmetric quant (without zeropoint), default is False.')

     args = parser.parse_args()

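One behavior note: the old `--zeropoint` flag used `store_false` (asymmetric quantization on by default), and the new `--sym` flag keeps that default while inverting the naming, so symmetric quantization is now explicitly opt-in. A small self-contained sketch of just the three new flags (not the full parser above, help texts paraphrased) showing their defaults:

```python
import argparse

# Minimal parser with only the newly added flags; all are off unless passed.
p = argparse.ArgumentParser()
p.add_argument('--ppl', action='store_true', help='keep logits for every input token (perplexity eval).')
p.add_argument('--awq', action='store_true', help='run AWQ weight search before export.')
p.add_argument('--sym', action='store_true', help='symmetric quantization, i.e. no zero point.')

print(p.parse_args([]))                  # Namespace(awq=False, ppl=False, sym=False)
print(p.parse_args(['--awq', '--sym']))  # Namespace(awq=True, ppl=False, sym=True)
```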