@@ -68,6 +68,8 @@ class LLM(torch.nn.Module):
 
     def __init__(self, args):
        super().__init__()
+        self.quant_bit = 4
+        self.asymmetric = True
        self.onnx_path = args.onnx_path
        self.mnn_path = args.mnn_path
        if not os.path.exists(self.onnx_path):
@@ -182,7 +184,7 @@ def export_lm(self):
        onnx_outs = ort_session.run(None, inputs)
        self.assert_equal(original_outs, onnx_outs)
        if self.export_mnn:
-            onnx2mnn(onnx_model, self.mnn_path)
+            onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric)
 
    def export_embed(self):
        model = self.embed
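The two extra arguments carry the weight-quantization settings introduced in LLM.__init__ down to the MNN conversion step. onnx2mnn itself is not part of this diff; the sketch below only illustrates how such a wrapper could forward the options to MNN's converter CLI, and the quantization flag names are an assumption rather than something taken from this commit.

import subprocess

def onnx2mnn_sketch(onnx_path, mnn_path, quant_bit=4, asymmetric=True):
    # Hypothetical wrapper, not the repository's onnx2mnn: convert an ONNX
    # model to MNN and request weight quantization. Flag names are assumed.
    cmd = [
        'mnnconvert', '-f', 'ONNX',
        '--modelFile', onnx_path,
        '--MNNModel', mnn_path,
        '--weightQuantBits', str(quant_bit),   # 4-bit weights by default
    ]
    if asymmetric:
        cmd.append('--weightQuantAsymmetric')  # assumed switch for asymmetric quantization
    subprocess.run(cmd, check=True)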
@@ -311,12 +313,25 @@ def export_tokenizer(self):
                token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
                fp.write(f'{token_encode} {score} {type}\n')
            fp.close()
-        else:
+        elif hasattr(self.tokenizer, 'mergeable_ranks'):
            # tiktoken
            with open(file_path, "w", encoding="utf8") as fp:
                for k, v in self.tokenizer.mergeable_ranks.items():
                    line = base64.b64encode(k).decode("utf8") + "\n"
                    fp.write(line)
+        else:
+            # other
+            with open(file_path, "w", encoding="utf8") as fp:
+                vocab = self.tokenizer.get_vocab()
+                vocab_list = ['<unk>' for i in range(len(vocab))]
+                for k, v in vocab.items():
+                    k = k.replace('Ċ', '\n')
+                    k = k.replace('Ġ', ' ')
+                    vocab_list[int(v)] = k
+                for v in vocab_list:
+                    line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
+                    fp.write(line)
+
 
 # chatglm
 class GLMBlock(torch.nn.Module):
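In the new fallback branch each vocabulary entry becomes one base64-encoded line, written in token-id order after mapping the byte-level BPE markers Ġ and Ċ back to a space and a newline. A minimal reader for such a file (assuming the exporter's usual tokenizer.txt naming; only this fallback format is handled, not the sentencepiece or tiktoken variants above):

import base64

def load_fallback_vocab(path):
    # Inverse of the fallback branch: one base64 token per line, line index == token id.
    with open(path, encoding='utf8') as fp:
        return [base64.b64decode(line.strip()).decode('utf-8') for line in fp]

# Usage sketch: vocab = load_fallback_vocab('onnx/tokenizer.txt'); print(vocab[:5])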
@@ -405,6 +420,7 @@ def __init__(self, block, block_id, final_layernorm = None):
        self.block = block
        self.block_id = block_id
        self.final_layernorm = final_layernorm
+        self.hidden_size = 4096
 
    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
        theta = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
@@ -447,7 +463,7 @@ def load_model(self, model_path: str):
        self.lm = Lm(self.lm_)
        self.blocks = [GLM2Block(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
        # some config for export
-        self.past_kv_shape = [28, 2, 0, 1, 2, 128]
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
        self.block_dynamic_axes = {
            "inputs_embeds" : { 0: "seq_len" },
            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
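past_kv_shape encodes the key/value cache layout: the leading entry is the number of blocks, and the 0 is the history axis that is empty on the first forward pass and presumably exported as the dynamic "history_len" dimension (it sits at index 1 of the per-block cache, matching "past_key_values": {1: "history_len"} in block_dynamic_axes, and at index 2 of the full-model cache, matching model_dynamic_axes). A small illustration of building an empty per-block cache from it; the commented torch.onnx.export call is a hedged sketch of how such dicts are typically consumed, not the exporter's actual export code:

import torch

# Drop the leading block count, keep the 0-sized history axis.
past_kv_shape = [28, 1, 0, 2, 32, 80]   # illustrative values only
past_key_values = torch.zeros(past_kv_shape[1:])
print(past_key_values.shape)            # torch.Size([1, 0, 2, 32, 80])

# torch.onnx.export(block, (inputs_embeds, attention_mask, position_ids, past_key_values),
#                   'block_0.onnx',
#                   input_names=['inputs_embeds', 'attention_mask', 'position_ids', 'past_key_values'],
#                   output_names=['hidden_states', 'presents'],
#                   dynamic_axes=block_dynamic_axes)  # the dict defined above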
@@ -663,6 +679,78 @@ def get_position_ids(self) -> torch.Tensor:
            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
 
+class PHI2Block(torch.nn.Module):
+    def __init__(self, block, block_id, hidden_size):
+        super().__init__()
+        self.block = block
+        self.block_id = block_id
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states,
+                                             past_kv,
+                                             rotary_pos_emb=rotary_pos_emb,
+                                             causal_mask=attention_mask
+                                             )
+        if self.block_id == 31:
+            hidden_states = hidden_states[:, -1, :]
+        return hidden_states, presents
+
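PHI2Block builds its rotary table inline: 16 inverse frequencies for a rotary dimension of 32, expanded per position and stacked into a [2, seq_len, 16] cos/sin tensor that is handed to the wrapped phi-2 layer; block 31 is the last of phi-2's 32 decoder layers, so only the final token's hidden state is kept there for the lm head. A standalone shape check of that table (values are for illustration only):

import torch

seq_len = 5
position_ids = torch.arange(seq_len, dtype=torch.long)
theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))  # [16] inverse frequencies
idx_theta = position_ids.float().reshape(-1, 1) * theta                      # [seq_len, 16] angles
rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
print(rotary_pos_emb.shape)  # torch.Size([2, 5, 16])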
+class phi_2(LLM):
+    def __init__(self, args):
+        super().__init__(args)
+        self.model_name = 'phi-2'
+        self.asymmetric = False # TODO: some precision bug when using asymmetric
+
+    def load_model(self, model_path: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+        transformer = model.transformer
+        self.lm_ = model.lm_head
+        self.embed_ = transformer.embd.wte
+        self.hidden_size = self.embed_.weight.shape[-1]
+        self.blocks_ = transformer.h
+        # self.final_layernorm_ = transformer.final_layernorm
+        # some wrapper
+        self.stop_id = self.tokenizer.eos_token_id
+        self.block_nums = len(self.blocks_)
+        self.embed = Embedding(self.embed_, self.embed_bf16)
+        self.lm = Lm(self.lm_)
+        self.blocks = [PHI2Block(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)]
+        # some config for export
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
+        self.block_dynamic_axes = {
+            "inputs_embeds" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 1: "history_len" }
+        }
+        self.model_dynamic_axes = {
+            "input_ids" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 2: "history_len" }
+        }
+
+    def build_prompt(self, query):
+        return f'Instruct: {query}\nOutput:'
+
+    def get_attention_mask(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.zeros([1, 1, 1, 1]).bool()
+        attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+        return attention_mask
+
+    def get_position_ids(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+
 if __name__ == '__main__':
    llm_models = {
        'chatglm-6b': Chatglm_6b,
@@ -672,7 +760,8 @@ def get_position_ids(self) -> torch.Tensor:
        'Qwen-7B-Chat': Qwen_7b_Chat,
        'Qwen-1_8B-Chat': Qwen_7b_Chat,
        'Baichuan2-7B-Chat': Llama2_7b_Chat,
-        'Llama-2-7b-chat-ms': Llama2_7b_Chat
+        'Llama-2-7b-chat-ms': Llama2_7b_Chat,
+        'phi-2': phi_2
    }
    parser = argparse.ArgumentParser(description='LLMExporter', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True,
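With 'phi-2' registered in llm_models, a phi-2 checkpoint passed via --path can presumably be dispatched to the new phi_2 class (the lookup itself is outside this hunk). Its attention-mask helper follows the same two-phase pattern as the other models: an inverted lower-triangular, i.e. strictly causal, boolean mask during prefill and a 1x1x1x1 all-False mask once decoding token by token. A quick standalone check of the prefill mask for a sequence of length 4:

import torch

seq_len = 4
mask = ~torch.tril(torch.ones([1, 1, seq_len, seq_len]).bool())
print(mask[0, 0].int())
# Expected pattern (True/1 = masked out, strictly upper triangle):
# [[0, 1, 1, 1],
#  [0, 0, 1, 1],
#  [0, 0, 0, 1],
#  [0, 0, 0, 0]]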