@@ -311,12 +311,25 @@ def export_tokenizer(self):
                 token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
                 fp.write(f'{token_encode} {score} {type}\n')
             fp.close()
-        else:
+        elif hasattr(self.tokenizer, 'mergeable_ranks'):
             # tiktoken
             with open(file_path, "w", encoding="utf8") as fp:
                 for k, v in self.tokenizer.mergeable_ranks.items():
                     line = base64.b64encode(k).decode("utf8") + "\n"
                     fp.write(line)
+        else:
+            # other (huggingface) tokenizers
+            with open(file_path, "w", encoding="utf8") as fp:
+                vocab = self.tokenizer.get_vocab()
+                vocab_list = ['<unk>' for i in range(len(vocab))]
+                for k, v in vocab.items():
+                    k = k.replace('Ċ', '\n')
+                    k = k.replace('Ġ', ' ')
+                    vocab_list[int(v)] = k
+                for v in vocab_list:
+                    line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
+                    fp.write(line)
+
 
 # chatglm
 class GLMBlock(torch.nn.Module):
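
Every branch of export_tokenizer above writes one base64-encoded token per line; the sentencepiece branch additionally appends a score and a type. A minimal sketch of reading such a file back, assuming the export file is named tokenizer.txt (the actual name comes from file_path, which is not shown in this hunk):

    import base64

    # Sketch only: decode the vocab file written by export_tokenizer above.
    # The "tokenizer.txt" name and the "token [score type]" line layout are
    # assumptions based on the write calls in the diff, not on the rest of the exporter.
    with open("tokenizer.txt", encoding="utf8") as fp:
        for rank, line in enumerate(fp):
            fields = line.rstrip("\n").split(" ")
            token = base64.b64decode(fields[0])   # raw token bytes
            extra = fields[1:]                    # score/type, sentencepiece branch only
            print(rank, token, extra)
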
@@ -405,6 +418,7 @@ def __init__(self, block, block_id, final_layernorm = None):
         self.block = block
         self.block_id = block_id
         self.final_layernorm = final_layernorm
+        self.hidden_size = 4096
 
     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
         theta = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
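
For context, the rotary table built in this forward() (and in the PHI2Block added later in this diff, with dim 32 instead of 64) follows the usual theta_i = 10000^(-2i/d) schedule. A standalone sketch with an arbitrary sequence length, just to make the shapes explicit:

    import torch

    # Rotary angles as computed in the block forward() methods in this diff.
    # dim=32 matches the phi-2 block below; seq_len=5 is arbitrary.
    dim, seq_len = 32, 5
    theta = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # [dim/2]
    position_ids = torch.arange(seq_len, dtype=torch.float32).reshape(-1, 1)       # [seq_len, 1]
    idx_theta = position_ids * theta                                               # [seq_len, dim/2]
    rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0)
    print(rotary_pos_emb.shape)  # torch.Size([2, 5, 16]) -> [cos/sin, seq_len, dim/2]
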
@@ -447,7 +461,7 @@ def load_model(self, model_path: str):
         self.lm = Lm(self.lm_)
         self.blocks = [GLM2Block(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
         # some config for export
-        self.past_kv_shape = [28, 2, 0, 1, 2, 128]
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
         self.block_dynamic_axes = {
             "inputs_embeds": {0: "seq_len"},
             "attention_mask": {2: "seq_len", 3: "seq_len"},
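
past_kv_shape is only the shape of the initial, empty KV cache; the 0 entry is the history axis that the *_dynamic_axes dicts mark as "history_len". A minimal sketch of building that placeholder (reading the remaining dimensions as layers, batch, history_len, key/value, heads, head_dim is my interpretation, not something stated in the diff):

    import torch

    # Sketch only: an all-zero, zero-length cache matching the new past_kv_shape.
    # The dimension reading (layers, batch, history_len, key/value, heads, head_dim)
    # is an assumption; 32 layers corresponds to len(self.blocks) for phi-2.
    past_kv_shape = [32, 1, 0, 2, 32, 80]
    past_key_values = torch.zeros(past_kv_shape)
    print(past_key_values.shape)  # torch.Size([32, 1, 0, 2, 32, 80])
    # During decoding, the zero-length history axis grows by one per generated token.
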
@@ -663,6 +677,77 @@ def get_position_ids(self) -> torch.Tensor:
             return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
         return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
 
+class PHI2Block(torch.nn.Module):
+    def __init__(self, block, block_id, hidden_size):
+        super().__init__()
+        self.block = block
+        self.block_id = block_id
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states,
+                                             past_kv,
+                                             rotary_pos_emb=rotary_pos_emb,
+                                             causal_mask=attention_mask
+                                             )
+        if self.block_id == 31:
+            hidden_states = hidden_states[:, -1, :]
+        return hidden_states, presents
+
+class phi_2(LLM):
+    def __init__(self, args):
+        super().__init__(args)
+        self.model_name = 'phi-2'
+
+    def load_model(self, model_path: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+        transformer = model.transformer
+        self.lm_ = model.lm_head
+        self.embed_ = transformer.embd.wte
+        self.hidden_size = self.embed_.weight.shape[-1]
+        self.blocks_ = transformer.h
+        # self.final_layernorm_ = transformer.final_layernorm
+        # some wrapper
+        self.stop_id = self.tokenizer.eos_token_id
+        self.block_nums = len(self.blocks_)
+        self.embed = Embedding(self.embed_, self.embed_bf16)
+        self.lm = Lm(self.lm_)
+        self.blocks = [PHI2Block(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)]
+        # some config for export
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
+        self.block_dynamic_axes = {
+            "inputs_embeds": {0: "seq_len"},
+            "attention_mask": {2: "seq_len", 3: "seq_len"},
+            "position_ids": {0: "seq_len"},
+            "past_key_values": {1: "history_len"}
+        }
+        self.model_dynamic_axes = {
+            "input_ids": {0: "seq_len"},
+            "attention_mask": {2: "seq_len", 3: "seq_len"},
+            "position_ids": {0: "seq_len"},
+            "past_key_values": {2: "history_len"}
+        }
+
+    def build_prompt(self, query):
+        return f'Instruct: {query}\nOutput:'
+
+    def get_attention_mask(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.zeros([1, 1, 1, 1]).bool()
+        attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+        return attention_mask
+
+    def get_position_ids(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+
 if __name__ == '__main__':
     llm_models = {
         'chatglm-6b': Chatglm_6b,
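
The phi_2 wrapper above masks with booleans: during prefill, get_attention_mask() returns the inverse of a lower-triangular matrix (True marks positions a token may not attend to); during decoding it returns a single unmasked position. A small illustration with seq_len = 4:

    import torch

    # Illustration of the masks produced by get_attention_mask() in the phi_2 class above.
    seq_len = 4
    prefill_mask = ~torch.tril(torch.ones([1, 1, seq_len, seq_len]).bool())
    print(prefill_mask[0, 0].int())
    # tensor([[0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 0, 1],
    #         [0, 0, 0, 0]], dtype=torch.int32)
    decode_mask = torch.zeros([1, 1, 1, 1]).bool()  # nothing masked for the single new token
    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)  # prefill ids: [[0, 1, 2, 3]]
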
@@ -672,7 +757,8 @@ def get_position_ids(self) -> torch.Tensor:
         'Qwen-7B-Chat': Qwen_7b_Chat,
         'Qwen-1_8B-Chat': Qwen_7b_Chat,
         'Baichuan2-7B-Chat': Llama2_7b_Chat,
-        'Llama-2-7b-chat-ms': Llama2_7b_Chat
+        'Llama-2-7b-chat-ms': Llama2_7b_Chat,
+        'phi-2': phi_2
     }
     parser = argparse.ArgumentParser(description='LLMExporter', formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True,
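
The new 'phi-2' entry maps the model name to the phi_2 wrapper defined above. A hypothetical driver built only from what this diff shows; how the real script dispatches on the model name, and whatever fields the LLM base class expects on args beyond --path, are not visible here:

    args = parser.parse_args(['--path', '/local/path/to/phi-2'])
    ModelClass = llm_models['phi-2']   # -> phi_2
    exporter = ModelClass(args)        # LLM.__init__(args) behavior is not shown in this diff
    exporter.load_model(args.path)     # load_model(model_path) is defined in the phi_2 class above
    print(exporter.build_prompt('What is the capital of France?'))
    # Instruct: What is the capital of France?
    # Output: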