@@ -82,6 +82,15 @@ def regist(self, model_type, model_map):

     def regist_models(self):
         self.defualt_map()
+        # regist models
+        self.regist_llama()
+        self.regist_qwen()
+        self.regist_glm()
+        self.regist_glm2()
+        self.regist_phi()
+        self.regist_gemma2()
+
+    def regist_llama(self):
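+        # llama and qwen2 share the default (LlamaForCausalLM-style) weight mapping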
         llama_map = self.default_map
         self.regist('llama', llama_map)
         self.regist('qwen2', llama_map)
@@ -92,10 +101,6 @@ def regist_models(self):
             'o_proj': 'o_proj'
         }
         self.regist('baichuan', baichuan_map)
-        self.regist_qwen()
-        self.regist_glm()
-        self.regist_glm2()
-        self.regist_phi()

     def regist_qwen(self):
         qwen_map = {
@@ -204,6 +209,20 @@ def regist_phi(self):
         }
         self.regist('phi-msft', phi_map)

+    def regist_gemma2(self):
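+        # gemma2 reuses the llama-style layout but names head_dim explicitly and adds
+        # pre/post feed-forward layernorms, so extend copies of the default config/decoder maps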
+        gemma2_config = copy.deepcopy(self.default_config)
+        gemma2_config['head_dim'] = 'head_dim'
+        gemma2_decoder = copy.deepcopy(self.default_decoder)
+        gemma2_decoder['pre_feedforward_layernorm'] = 'pre_feedforward_layernorm'
+        gemma2_decoder['post_feedforward_layernorm'] = 'post_feedforward_layernorm'
+        gemma2_map = {
+            'config': gemma2_config,
+            'model': self.defualt_model,
+            'decoder': gemma2_decoder,
+            'attention': self.default_attention
+        }
+        self.regist('gemma2', gemma2_map)
+
     def defualt_map(self):
         # default map is `LlamaForCausalLM`
         self.config_key = 'config'
@@ -356,10 +375,11 @@ def rebuild(self):
         return self.onnx_weight_path

 class MNNConveter:
-    def __init__(self, onnx_path, weight_ops, quant_bit=4, quant_block=0):
+    def __init__(self, onnx_path, weight_ops, config):
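+        # quantization settings now come from the exporter config: quant_bit, quant_block and lm_quant_bit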
         self.weight_ops = weight_ops
-        self.quant_block = quant_block
-        self.quant_bit = quant_bit
+        self.quant_block = config.quant_block
+        self.quant_bit = config.quant_bit
+        self.lm_quant_bit = config.lm_quant_bit
         self.mnn_weight_offset = 0
         self.onnx_model_path = onnx_path
         self.mnn_model_path = onnx_path.replace('.onnx', '.mnn')
@@ -458,49 +478,49 @@ def rebuild(self, json_path):
             json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
         return self.mnn_weight_path

-    def quant(self, weight):
+    def quant(self, weight, quant_bit, quant_block):
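+        # asymmetric block-wise quantization: each block of quant_block input channels gets its own
+        # min/scale pair (returned in alpha); 4-bit values are packed two per byte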
         weight = weight.numpy()
         oc, ic = weight.shape
-        if self.quant_block == 0:
+        if quant_block == 0:
             block_size = ic
         else:
-            block_size = self.quant_block
+            block_size = quant_block
         block_num = ic // block_size
         weight = weight.reshape(oc, block_num, block_size)
         max_val = np.max(weight, axis=-1, keepdims=True)
         min_val = np.min(weight, axis=-1, keepdims=True)
-        offset = 1 << (self.quant_bit - 1)
+        offset = 1 << (quant_bit - 1)
         clip_max = offset - 1
         clip_min = -offset
         scale = (max_val - min_val) / (clip_max - clip_min)
         q_weight = np.round((weight - min_val) / scale).astype(np.int8) + clip_min
         q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
         q_weight = q_weight.reshape(-1, 2)
-        if self.quant_bit == 4:
+        if quant_bit == 4:
             q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
         alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
         return q_weight, alpha, clip_min

     def write_npy(self, data):
         return self.mnn_weight.write(data.tobytes())

-    def write_header(self, ic, oc):
+    def write_header(self, ic, oc, quant_bit):
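+        # block header layout: dimension count byte, the [oc, ic] shape, then the length-prefixed int8 value map for this bit width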
         dim_num = self.mnn_weight.write(b'\x02')
         shape_dtype = np.int16
         if oc > 65535 or ic > 65535:
             shape_dtype = np.int32
         dim_length = self.write_npy(np.array([oc, ic]).astype(shape_dtype))
-        offset = 1 << (self.quant_bit - 1)
+        offset = 1 << (quant_bit - 1)
         weight_map = [i for i in range(-offset, offset)]
         weight_map.insert(0, len(weight_map))
         map_length = self.write_npy(np.array(weight_map, dtype=np.int8))
         header_length = dim_num + dim_length + map_length
         return header_length, shape_dtype == np.int32

-    def build_weight(self, linear):
+    def build_weight(self, linear, quant_bit, quant_block):
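+        # quantize one linear layer with the given bit width / block size and append weight, alpha and bias to the MNN weight file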
         ic, oc = linear.in_features, linear.out_features
-        q_weight, alpha, q_min = self.quant(linear.weight.data)
-        header_len, shape_int32 = self.write_header(ic, oc)
+        q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block)
+        header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
         weight_len = self.write_npy(q_weight) + header_len
         alpha_len = self.write_npy(alpha)
         if linear.bias is not None:
@@ -535,7 +555,10 @@ def rebuild_op(self, op, graph):
                 linear.out_features == oc and
                 (linear.bias is not None) == has_bias)

-        external, q_min, shape_int32 = self.build_weight(linear)
+
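+        # lm_head can use its own bit width (lm_quant_bit); all other linear layers use quant_bit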
+        quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit
+        print(f'quant layer {name}: bits {quant_bit}, block {self.quant_block}')
+        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block)

         origin_input = op['inputIndexes']
         origin_output = op['outputIndexes']
@@ -630,9 +653,13 @@ def __init__(self, embed, config):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.embed = embed
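+        # gemma2 scales input embeddings by sqrt(hidden_size); fold that scale into the
+        # embedding weights once at export time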
+        if config.model_type == 'gemma2':
+            normalizer = torch.tensor(self.hidden_size ** 0.5)
+            self.embed.weight.data *= normalizer

     def forward(self, input_ids):
-        return self.embed(input_ids).view(-1, 1, self.hidden_size)
+        inputs_embeds = self.embed(input_ids).view(-1, 1, self.hidden_size)
+        return inputs_embeds

 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
@@ -695,6 +722,7 @@ def forward(
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
+        # print(f'hidden_states.shape = {hidden_states.shape}, query_states.shape = {query_states.shape}')
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
@@ -734,7 +762,7 @@ def forward(
         attn_output = torch.matmul(attn_weights, value_states)

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
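+        # reshape with -1 rather than hidden_size: for gemma2, num_heads * head_dim differs from hidden_size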
+        attn_output = attn_output.reshape(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
         return attn_output, past_key_value

@@ -832,22 +860,32 @@ def forward(
             past_key_value=past_key_value,
         )
         # Fully Connected
-        if self.alpha != 1.0:
+        if not hasattr(self, 'post_attention_layernorm'):
+            # phi
+            feed_forward_hidden_states = self.mlp(norm_hidden_states)
+            hidden_states = hidden_states + feed_forward_hidden_states + residual
+        elif self.alpha != 1.0:
             # chatglm-6b
             hidden_states = norm_hidden_states * self.alpha + hidden_states
             mlp_input = self.post_attention_layernorm(hidden_states)
             mlp_output = self.mlp(mlp_input)
             hidden_states = mlp_input * self.alpha + mlp_output
-        elif hasattr(self, 'post_attention_layernorm'):
+        elif hasattr(self, 'pre_feedforward_layernorm'):
+            # gemma2
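+            # post_attention_layernorm is applied to the attention output before the residual add,
+            # and the MLP is wrapped by pre/post feed-forward layernorms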
+            hidden_states = self.post_attention_layernorm(hidden_states)
             hidden_states = residual + hidden_states
             residual = hidden_states
-            hidden_states = self.post_attention_layernorm(hidden_states)
+            hidden_states = self.pre_feedforward_layernorm(hidden_states)
             hidden_states = self.mlp(hidden_states)
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
             hidden_states = residual + hidden_states
         else:
-            # phi
-            feed_forward_hidden_states = self.mlp(norm_hidden_states)
-            hidden_states = hidden_states + feed_forward_hidden_states + residual
+            # general
+            hidden_states = residual + hidden_states
+            residual = hidden_states
+            hidden_states = self.post_attention_layernorm(hidden_states)
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = residual + hidden_states

         return hidden_states, present_key_value

@@ -886,6 +924,10 @@ def init_from_args(self, args):
         self.skip_slim = args.skip_slim
         self.quant_bit = args.quant_bit
         self.quant_block = args.quant_block
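+        # lm_head falls back to the global quant_bit when --lm_quant_bit is not given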
+        if args.lm_quant_bit is not None:
+            self.lm_quant_bit = args.lm_quant_bit
+        else:
+            self.lm_quant_bit = self.quant_bit
         # init export dst dir
         if not os.path.exists(self.dst_path):
             os.makedirs(self.dst_path)
@@ -920,6 +962,7 @@ def load_model(self, model_path):
             self.stop_ids.append(id)
         self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None]
         model_mapper = ModelMapper()
+
         self.model_type, self.model_map = model_mapper.get_map(self.config)
         # print(self.model)
         # print(self.model_type, self.model_map)
@@ -929,7 +972,8 @@ def load_model(self, model_path):
             self.num_key_value_heads = self.num_attention_heads
         if not hasattr(self, 'rope_theta') or self.rope_theta is None:
             self.rope_theta = 10000.0
-        self.head_dim = self.hidden_size // self.num_attention_heads
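+        # some configs (e.g. gemma2) define head_dim explicitly; only derive it from hidden_size when it is missing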
+        if not hasattr(self, 'head_dim') or self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
         # some export info
         self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim]
         self.block_dynamic_axes = {
@@ -1082,6 +1126,8 @@ def build_prompt(self, query):
             return f'{query}[gMASK]<sop>'
         if 'phi-2' in self.model_name:
             return f'Instruct: {query}\nOutput:'
+        if 'gemma-2' in self.model_name:
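+            # gemma-2 chat template: <start_of_turn>user / <start_of_turn>model turns with an explicit <bos>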
+            return f'<bos><start_of_turn>user\n{query}<end_of_turn>\n<start_of_turn>model\n'
         return query

     def str_to_ids(self, prompt):
@@ -1246,7 +1292,7 @@ def export(self, export_type):
             self.onnx_slim(onnx_model)
         if export_mnn:
             # convert onnx to mnn and quant weight
-            MNNConveter(onnx_model, self.unloaded_ops, self.quant_bit, self.quant_block).export()
+            MNNConveter(onnx_model, self.unloaded_ops, self).export()
         else:
             # export weight to llm.onnx.data
             self.onnx_load_param(onnx_model)
@@ -1479,7 +1525,7 @@ def export(self, export_type):
         if not self.skip_slim:
             self.onnx_slim(onnx_model)
         if 'mnn' in export_type:
-            MNNConveter(onnx_model, None, self.quant_bit, self.quant_block).export()
+            MNNConveter(onnx_model, None, self).export()

     def build_prompt(self, query):
         return f'[CLS]{query}[SEP]'
@@ -1506,7 +1552,8 @@ def get_attention_mask(self) -> torch.Tensor:
     parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
-    parser.add_argument('--quant_block', type=int, default=0, help='mnn quant block, default is 0 mean channle-wise.')
+    parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 128; 0 means channel-wise.')
+    parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')

     args = parser.parse_args()
     model_path = args.path