@@ -94,6 +94,7 @@ def regist_llama(self):
         self.regist('llama', llama_map)
         self.regist('qwen2', llama_map)
         self.regist('internlm', llama_map)
+        self.regist('mobilellm', llama_map)
         # baichuan
         baichuan_map = copy.deepcopy(self.default_map)
         baichuan_map[self.attention_key] = {
@@ -365,9 +366,7 @@ def __init__(
         self.tokenizer = model.tokenizer
         self.w_bit = model.quant_bit
         self.group_size = model.quant_block
-        self.zeropoint = False
-        # self.calib_data = model.calib_data
-        # self.split = model.split
+        self.zeropoint = not model.symmetric
         self.calib_data = 'ag_news'
         self.split = 'test'
         self.duo_scaling = True
@@ -390,42 +389,20 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
         w = w.reshape(-1, self.group_size)
         assert w.dim() == 2
         assert torch.isnan(w).sum() == 0
-
         # zero point quantization
         if self.zeropoint:
-            '''
-            max_val = w.amax(dim=1, keepdim=True)
-            min_val = w.amin(dim=1, keepdim=True)
-            max_int = 2**self.w_bit - 1
-            min_int = 0
-            scales = (max_val - min_val).clamp(min=1e-5) / max_int
-            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-            w = (
-                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
-            ) * scales
-            zeros = zeros.view(org_w_shape[0], -1)
-            '''
-            #'''
             max_val = w.amax(dim=1, keepdim=True)
             min_val = w.amin(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
             clip_min = -offset
             scales = (max_val - min_val) / (clip_max - clip_min)
+            zeros = -torch.round(min_val / scales) + clip_min
+            qw = torch.round(w / scales) + zeros
+            qw = torch.clamp(qw, clip_min, clip_max)
+            w = (qw - zeros) * scales
             zeros = min_val.view(org_w_shape[0], -1)
-            qw = torch.clamp(torch.round((w - min_val) / scales) + clip_min, clip_min, clip_max)
-            w = (qw - clip_min) * scales + min_val
-            #'''
         else:
-            '''
-            max_val = w.abs().amax(dim=1, keepdim=True)
-            max_val = max_val.clamp(min=1e-5)
-            max_int = 2 ** (self.w_bit - 1) - 1
-            min_int = -(2 ** (self.w_bit - 1))
-            scales = max_val / max_int
-            zeros = None
-            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
-            '''
             abs_max = w.abs().amax(dim=1, keepdim=True)
             offset = 1 << (self.w_bit - 1)
             clip_max = offset - 1
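The rewritten zero-point branch above quantizes against an integer zero point (zeros = -round(min/scale) + clip_min) rather than shifting by min_val before rounding, so the round and clamp both happen on integer codes. A minimal standalone sketch of that round trip, not part of the patch and with illustrative names only:

import torch

def fake_quant_asym(w: torch.Tensor, w_bit: int = 4) -> torch.Tensor:
    # w: [num_groups, group_size]; mirrors the zero-point branch above
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    offset = 1 << (w_bit - 1)
    clip_max, clip_min = offset - 1, -offset
    scales = (max_val - min_val) / (clip_max - clip_min)
    zeros = -torch.round(min_val / scales) + clip_min
    qw = torch.clamp(torch.round(w / scales) + zeros, clip_min, clip_max)
    return (qw - zeros) * scales  # dequantized weights, same shape as w

w = torch.randn(8, 128)
err = (fake_quant_asym(w) - w).abs().max().item()
print(f'max round-trip error: {err:.4f}')  # roughly one scale step per group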
@@ -464,7 +441,7 @@ def quantize(self):
             ].to(common_device)

             self.inps = self.inps.to(common_device)
-            print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')
+            # print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')

             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = AwqQuantizer.get_named_linears(self.modules[i])
@@ -522,7 +499,7 @@ def quantize(self):
                 self._search_best_scale(self.modules[i], **layer)
                 for layer in module_config
             ]
-            print(scales_list); exit(0)
+            # print(scales_list); exit(0)
             AwqQuantizer.apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             # [STEP 3]: Compute and apply clipping list
             if self.apply_clip:
@@ -1316,7 +1293,7 @@ def __init__(self, onnx_path, weight_ops, config):
         self.quant_block = config.quant_block
         self.quant_bit = config.quant_bit
         self.lm_quant_bit = config.lm_quant_bit
-        self.zeropoint = config.zeropoint
+        self.symmetric = config.symmetric
         self.mnn_weight_offset = 0
         self.onnx_model_path = onnx_path
         self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
@@ -1430,7 +1407,7 @@ def rebuild(self, json_path):
             json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
         return self.mnn_weight_path

-    def quant(self, weight, quant_bit, quant_block, zeropoint):
+    def quant(self, weight, quant_bit, quant_block, symmetric):
         weight = weight.numpy()
         oc, ic = weight.shape
         if quant_block == 0:
@@ -1441,27 +1418,32 @@ def quant(self, weight, quant_bit, quant_block, zeropoint):
         weight = weight.reshape(oc, block_num, block_size)
         offset = 1 << (quant_bit - 1)
         clip_max = offset - 1
-        if zeropoint:
-            clip_min = -offset
-            max_val = np.max(weight, axis=-1, keepdims=True)
-            min_val = np.min(weight, axis=-1, keepdims=True)
-            scale = (max_val - min_val) / (clip_max - clip_min)
-
-            # q_weight = np.round((weight - min_val) / scale) + clip_min
-            q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
-            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-            alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        else:
+        if symmetric:
             clip_min = -clip_max
             abs_max = np.max(np.abs(weight), axis=-1, keepdims=True)
             scale = abs_max / clip_max
             q_weight = np.round(weight / scale)
             q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
             alpha = scale.flatten()
+        else:
+            clip_min = -offset
+            max_val = np.max(weight, axis=-1, keepdims=True)
+            min_val = np.min(weight, axis=-1, keepdims=True)
+            scale = (max_val - min_val) / (clip_max - clip_min)
+
+            if False:
+                q_weight = np.round((weight - min_val) / scale) + clip_min
+                zeros = min_val - scale * clip_min
+            else:
+                q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
+                zeros = (np.round(min_val / scale) - clip_min) * scale
+            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
+            alpha = np.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
         q_weight = q_weight.reshape(-1, 2)
         if quant_bit == 4:
             q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]

+        clip_min = 1
         return q_weight, alpha, clip_min

     def write_npy(self, data):
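With the branches swapped, the asymmetric path now stores a dequantization offset zeros = (round(min/scale) - clip_min) * scale in alpha instead of the raw min_val, so a block dequantizes as w ≈ (q - offset) * scale + zeros, where q is the uint8 code written before int4 packing. A small NumPy sketch of that encode/decode pair (illustrative shapes, not part of the patch):

import numpy as np

quant_bit, block_size = 4, 128
w = np.random.randn(1, 1, block_size).astype(np.float32)  # one output channel, one block

offset = 1 << (quant_bit - 1)
clip_max, clip_min = offset - 1, -offset
max_val = w.max(axis=-1, keepdims=True)
min_val = w.min(axis=-1, keepdims=True)
scale = (max_val - min_val) / (clip_max - clip_min)
zeros = (np.round(min_val / scale) - clip_min) * scale

q = np.round(w / scale) - np.round(min_val / scale) + clip_min
q = (np.clip(q, clip_min, clip_max) + offset).astype(np.uint8)

w_hat = (q.astype(np.float32) - offset) * scale + zeros
print(np.abs(w_hat - w).max())  # reconstruction error, roughly one scale step

packed = q.flatten().reshape(-1, 2)
packed = packed[:, 0] * 16 + packed[:, 1]  # two 4-bit codes per byte, high nibble first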
@@ -1483,12 +1465,18 @@ def write_header(self, ic, oc, quant_bit):
         header_length = dim_num + dim_length + map_length
         return header_length, shape_dtype == np.int32

-    def build_weight(self, linear, quant_bit, quant_block, zeropoint):
+    def build_weight(self, linear, quant_bit, quant_block, symmetric):
         ic, oc = linear.in_features, linear.out_features
-        q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, zeropoint)
-        header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
-        weight_len = self.write_npy(q_weight) + header_len
-        alpha_len = self.write_npy(alpha)
+        if quant_bit == 16:
+            half_weight = linear.weight.data.half().flatten().numpy()
+            weight_len = self.write_npy(half_weight)
+            alpha_len, q_min, shape_int32 = 0, 0, False
+        else:
+            assert(quant_bit in (4, 8))
+            q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
+            header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
+            weight_len = self.write_npy(q_weight) + header_len
+            alpha_len = self.write_npy(alpha)
         if linear.bias is not None:
             bias = linear.bias.data.flatten().numpy()
             bias_length = self.write_npy(bias)
@@ -1548,7 +1536,7 @@ def rebuild_linear(self, op, graph):

         quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit
         block_size = ic if self.quant_block == 0 else self.quant_block
-        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block, self.zeropoint)
+        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric)

         origin_input = op['inputIndexes']
         origin_output = op['outputIndexes']
@@ -1589,12 +1577,22 @@ def rebuild_linear(self, op, graph):
             },
             "defaultDimentionFormat": "NHWC"
         }
-        if self.zeropoint:
-            aMin = q_min
-            readType = oc * (ic // block_size)
+
+        if quant_bit == 16:
+            quanParameter = { "type": 3 }
         else:
-            aMin = 0
-            readType = 0
+            if self.symmetric:
+                aMin = 0
+                readType = 0
+            else:
+                aMin = q_min
+                readType = oc * (ic // block_size)
+
+            quanParameter = {
+                "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
+                "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
+                "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
+            }
         conv_op = {
             "name": conv_name,
             "inputIndexes": pre_convert_output,
@@ -1608,11 +1606,7 @@ def rebuild_linear(self, op, graph):
                     'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
                     'relu6': False, 'inputCount': ic, 'hasOutputShape': False
                 },
-                "quanParameter": {
-                    "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
-                    "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
-                    "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
-                },
+                "quanParameter": quanParameter,
                 "external": external
             },
             "defaultDimentionFormat": "NHWC"
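The conv op's quanParameter is now chosen per layer: quant_bit == 16 gets the bare {"type": 3} (the fp16 weights were written raw by build_weight, with no alpha or header), while 4/8-bit layers keep type 1 and only differ in aMin/readType. In the asymmetric case readType = oc * (ic // block_size), i.e. the number of quantization blocks and hence of (zero, scale) pairs in alpha. A helper-style sketch of that selection, with hypothetical layer sizes (not part of the patch):

def pick_quan_parameter(quant_bit, symmetric, q_min, oc, ic, block_size, shape_int32):
    # mirrors the branch above; illustrative only
    if quant_bit == 16:
        return {"type": 3}
    aMin, readType = (0, 0) if symmetric else (q_min, oc * (ic // block_size))
    return {
        "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
        "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
        "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
    }

# e.g. a hypothetical 4096x4096 linear with 128-wide blocks -> 131072 (zero, scale) pairs
print(pick_quan_parameter(4, False, 1, 4096, 4096, 128, False)["readType"])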
@@ -1938,10 +1932,10 @@ def __init__(self, lm_, final_layernorm_, config):
         self.final_layernorm = final_layernorm_
         self.lm = lm_
         self.hidden_size = config.hidden_size
-        self.all_logits = config.all_logits
+        self.ppl = config.ppl

     def forward(self, hidden_states):
-        if not self.all_logits:
+        if not self.ppl:
             # just need last logit for predict next token
             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         hidden_states = self.final_layernorm(hidden_states)
@@ -2233,10 +2227,11 @@ def init_from_args(self, args):
         self.tokenizer_path = args.tokenizer_path
         self.lora_path = args.lora_path
         self.skip_slim = args.skip_slim
-        self.all_logits = args.all_logits
+        self.ppl = args.ppl
+        self.awq = args.awq
         self.quant_bit = args.quant_bit
         self.quant_block = args.quant_block
-        self.zeropoint = args.zeropoint
+        self.symmetric = args.sym
         self.mnnconvert = args.mnnconvert
         if self.tokenizer_path is None:
             self.tokenizer_path = self.path
@@ -2251,10 +2246,10 @@ def init_from_args(self, args):
             os.makedirs(self.onnx_path)

     def load_pretrained(self, model_path: str):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True, use_fast=False)
         if 'Qwen2-VL' in model_path:
             from transformers import Qwen2VLForConditionalGeneration
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).half().eval()
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path).float().eval()
         elif 'Llama-3.2' in model_path and 'Vision' in model_path:
             from transformers import MllamaForConditionalGeneration
             self.model = MllamaForConditionalGeneration.from_pretrained(model_path).float().eval()
@@ -2297,7 +2292,7 @@ def load_model(self, model_path):
         model_mapper = ModelMapper()

         self.model_type, self.model_map = model_mapper.get_map(self.config)
-        print(self.config, self.model_type, self.model_map, self.model)
+        # print(self.config, self.model_type, self.model_map, self.model)
         # load config info
         ModelMapper.do_map(self, self.config, self.model_map['config'])
         if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
@@ -2420,7 +2415,7 @@ def forward(self,
             hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
             presents[i] = kv
         logits = self.lm(hidden_states)
-        if not self.all_logits:
+        if not self.ppl:
             logits = logits.reshape(-1)
         if presents[0].shape == presents[-1].shape and None not in presents:
             presents = torch.stack(presents)
@@ -2478,7 +2473,6 @@ def id_to_str(self, token_id):

     def response(self, query):
         # self.imitate_quant()
-        self.awq_quant()
         prompt = self.build_prompt(query)
         input_ids = self.str_to_ids(prompt)
         if self.visual is not None:
@@ -2562,29 +2556,6 @@ def export_config(self, mnn_config = False):
             json.dump(config, f, ensure_ascii=False, indent=4)
         return config_json

-    def quant(self, weight, quant_bit, quant_block):
-        weight = weight.numpy()
-        oc, ic = weight.shape
-        if quant_block == 0:
-            block_size = ic
-        else:
-            block_size = quant_block
-        block_num = ic // block_size
-        weight = weight.reshape(oc, block_num, block_size)
-        max_val = np.max(weight, axis=-1, keepdims=True)
-        min_val = np.min(weight, axis=-1, keepdims=True)
-        offset = 1 << (quant_bit - 1)
-        clip_max = offset - 1
-        clip_min = -offset
-        scale = (max_val - min_val) / (clip_max - clip_min)
-        q_weight = np.round((weight - min_val) / scale) + clip_min
-        q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
-        q_weight = q_weight.reshape(-1, 2)
-        if quant_bit == 4:
-            q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
-        alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
-        return q_weight, alpha, clip_min
-
     def imitate_quant(self):
         def quant_dequant(linear, quant_bit=self.quant_bit, quant_block=self.quant_block):
             weight = linear.weight.data
@@ -2677,13 +2648,13 @@ def export_onnx(self):
         return onnx_model

     def awq_quant(self):
-        print('### AWQ quant')
         self.awq_quantizer = AwqQuantizer(self)
         self.awq_quantizer.quantize()
         self.is_awq_quantized = True

     def export(self, export_type):
-        # self.awq_quant()
+        if self.awq:
+            self.awq_quant()
         export_mnn = export_type == 'mnn'
         # export tokenizer
         self.export_tokenizer()
@@ -2751,6 +2722,14 @@ def write_header(fp, type, speicals, prefix = []):
         prefix_list = []
         if hasattr(self.tokenizer, 'get_prefix_tokens'):
             prefix_list = self.tokenizer.get_prefix_tokens()
+        if len(prefix_list) == 0:
+            test_txt = 'A'
+            ids = self.tokenizer.encode(test_txt)
+            get_txt = self.tokenizer.decode(ids[-1])
+            if len(ids) > 1 and get_txt == test_txt:
+                prefix_list += ids[:-1]
+        print(prefix_list)
+
         if self.sp_model is not None:
             # senetencepiece
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
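The new fallback probes the tokenizer itself: encode a single character, and if the encoding has more than one id while the last id alone decodes back to that character, the leading ids are treated as implicit prefix tokens (e.g. BOS). A standalone sketch of the same heuristic; the checkpoint path is a placeholder, not from the patch:

from transformers import AutoTokenizer

def detect_prefix_tokens(tokenizer, probe: str = 'A'):
    ids = tokenizer.encode(probe)
    # The last id alone must round-trip to the probe; everything before it is prefix.
    if len(ids) > 1 and tokenizer.decode(ids[-1]) == probe:
        return ids[:-1]
    return []

tokenizer = AutoTokenizer.from_pretrained('/path/to/model', use_fast=False)  # placeholder path
print(detect_prefix_tokens(tokenizer))  # e.g. [bos_token_id] for Llama-style tokenizers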
@@ -3041,13 +3020,13 @@ def main():
     parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
     parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
-    parser.add_argument('--all_logits', action='store_true', help='Whether or not to get all logits of input tokens.')
-    parser.add_argument('--quant_type', type=str, default='RTN', help='Quant type in [RTN, AWQ], default is `RTN`')
     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
-    parser.add_argument('--zeropoint', action='store_false', help='Whether or not to zeropoint when quant weight, default is True.')
     parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 0 mean channle-wise.')
     parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
     parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
+    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
+    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
+    parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.')

     args = parser.parse_args()

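Net effect on the CLI: --all_logits becomes --ppl, and the old --quant_type / --zeropoint pair is replaced by the opt-in flags --awq (run the AWQ scale search before export) and --sym (symmetric weights with no zero point; asymmetric stays the default). An illustrative invocation using only flags shown in this diff; the script name and the model-argument placeholder are assumptions, not from the patch:

python llmexport.py <model args> --export mnn --quant_bit 4 --quant_block 128 --awq --sym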