 import numpy as np
 from onnxslim import slim
 import onnxruntime as ort
-import _tools as MNNTools
 import sentencepiece as spm
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+try:
+    import _tools as MNNTools
+except ImportError:
+    MNNTools = None
 
 def onnx2mnn(onnx_path, mnn_dir, quant_bit=4, asymmetric=True, external_data=False, bizCode: str = None):
     model_name, model_extension = os.path.splitext(os.path.basename(onnx_path))
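For reference, a minimal usage sketch of `onnx2mnn` as it is called later in `export()` (the paths are placeholders, not part of the diff): quantize to 4-bit asymmetric and keep weights as external data, since the single-file model exceeds 2GB.

    onnx2mnn('./onnx/llm.onnx', './mnn', quant_bit=4, asymmetric=True, external_data=True)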
@@ -83,13 +86,25 @@ def __init__(self, args):
         self.export_mnn = args.export_mnn
         self.export_verbose = args.export_verbose
         self.export_test = args.export_test
-        self.embed_bf16 = args.embed_bf16
+        # Defaults to False; set to True only when exporting the single model
+        # without its embedding, i.e. `python llm_export ../path --export --embed_bin`.
+        self.without_embed = False
+        self.embed_bin = args.embed_bin
+        if self.embed_bin:
+            self.embed_bf16 = True
+        else:
+            self.embed_bf16 = args.embed_bf16
         self.skip_slim = args.skip_slim
         tokenizer_model = os.path.join(args.path, 'tokenizer.model')
         if os.path.exists(tokenizer_model):
             self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
         else:
             self.sp_model = None
+        merge_file = os.path.join(args.path, 'merges.txt')
+        if os.path.exists(merge_file):
+            self.merge_txt = merge_file
+        else:
+            self.merge_txt = None
         self.stop_ids = []
         self.max_length = 1024
         self.hidden_size = 4096
@@ -111,11 +126,14 @@ def export_vocab(self):
     def visual_embed(self, input_ids):
         raise NotImplementedError
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values):
-        if self.visual is not None and past_key_values[0] is None:
-            hidden_states = self.visual_embed(input_ids)
+    def __embedding(self, input_ids):
+        if self.visual is not None and self.token_len == 0:
+            input_embeds = self.visual_embed(input_ids)
         else:
-            hidden_states = self.embed(input_ids)
+            input_embeds = self.embed(input_ids)
+        return input_embeds
+
+    def __decode(self, hidden_states, attention_mask, position_ids, past_key_values):
         presents = []
         for i in range(self.block_nums):
             hidden_states, kv = self.blocks[i](hidden_states, attention_mask, position_ids, past_key_values[i])
@@ -126,6 +144,11 @@ def forward(self, input_ids, attention_mask, position_ids, past_key_values):
         self.token_len += 1
         return token_id, presents
 
+    def forward(self, input_ids, attention_mask, position_ids, past_key_values):
+        if self.without_embed:
+            return self.__decode(input_ids, attention_mask, position_ids, past_key_values)
+        return self.__decode(self.__embedding(input_ids), attention_mask, position_ids, past_key_values)
+
     # some test functions
     def build_prompt(self, query):
         if hasattr(self.tokenizer, 'build_prompt'):
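A minimal, self-contained sketch of the dispatch pattern the new `forward` implements (toy dimensions and class name, not the real model): splitting embedding from decoding lets the same module run either from token ids or from precomputed embeddings.

    import torch

    class TinyLLM(torch.nn.Module):
        # Toy stand-in: the embedding lookup is separable from decoding,
        # so an exported graph can start directly at inputs_embeds.
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(100, 8)
            self.proj = torch.nn.Linear(8, 8)
            self.without_embed = False

        def forward(self, x):
            if not self.without_embed:
                x = self.embed(x)   # x holds token ids
            return self.proj(x)     # shared decode path

    m = TinyLLM()
    ids = torch.tensor([1, 2, 3])
    out1 = m(ids)                   # normal path: ids -> embed -> decode
    m.without_embed = True
    out2 = m(m.embed(ids))          # precomputed-embedding path
    assert torch.allclose(out1, out2)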
@@ -233,6 +256,14 @@ def export_visual(self):
 
     def export_embed(self):
         model = self.embed
+        if self.embed_bin:
+            import ctypes
+            # dump the raw bfloat16 weight bytes (2 bytes per element)
+            tensor_data = model.embed.weight.data
+            data_ptr = tensor_data.untyped_storage().data_ptr()
+            buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
+            with open(f'./{self.mnn_path}/embeddings_bf16.bin', 'wb') as f:
+                f.write(buffer)
+            return
         input_ids = torch.arange(3, dtype=torch.long)
         onnx_model = f'./{self.onnx_path}/embedding.onnx'
         torch.onnx.export(model, (input_ids),
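Since the file written above is a raw dump of bfloat16 values, it can be loaded back for inspection with numpy. A minimal sketch (the file name comes from the code above; the reshape width is an assumption): bfloat16 is the upper 16 bits of an IEEE float32, so widening each element to uint32 and shifting left by 16 recovers the float32 bit pattern.

    import numpy as np

    # read raw bfloat16 values and widen to float32 for inspection
    raw = np.fromfile('embeddings_bf16.bin', dtype=np.uint16)
    weights = (raw.astype(np.uint32) << 16).view(np.float32)
    weights = weights.reshape(-1, 4096)  # assumption: hidden_size == 4096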
@@ -308,6 +339,10 @@ def export(self):
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape)
         onnx_model = f'./{self.onnx_path}/llm.onnx'
+        if self.embed_bin:
+            self.without_embed = True
+            input_ids = self.__embedding(input_ids)
+        print('export start ...')
         torch.onnx.export(
             model, (input_ids, attention_mask, position_ids, past_key_values),
             onnx_model,
@@ -319,6 +354,7 @@ def export(self):
             dynamic_axes=self.model_dynamic_axes,
             do_constant_folding=True,
             opset_version=15)
+        print('export done!')
         if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         if self.export_test:
@@ -336,11 +372,14 @@ def export(self):
         if self.export_mnn:
             # the single model is > 2GB, so use external_data
             onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric, True)
+        if self.without_embed:
+            self.without_embed = False
 
     def export_tokenizer(self):
         file_path = os.path.join(self.onnx_path, "tokenizer.txt")
         if self.sp_model is not None:
             # sentencepiece
+            print('# sentencepiece tokenizer')
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
             USER_DEFINED = 4; UNUSED = 5; BYTE = 6
             fp = open(file_path, "w", encoding="utf8")
@@ -365,6 +404,7 @@ def export_tokenizer(self):
                 fp.write(f'{token_encode} {score} {type}\n')
             fp.close()
         elif hasattr(self.tokenizer, 'mergeable_ranks'):
+            print('# tiktoken tokenizer')
             # tiktoken
             with open(file_path, "w", encoding="utf8") as fp:
                 for k, v in self.tokenizer.mergeable_ranks.items():
@@ -374,6 +414,25 @@ def export_tokenizer(self):
                 for k, v in self.tokenizer.special_tokens.items():
                     line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
                     fp.write(line)
+        elif self.merge_txt is not None:
+            # huggingface BPE tokenizer with a merges.txt
+            merge_list = []
+            vocab = self.tokenizer.get_vocab()
+            vocab_list = ['<unk>' for i in range(len(vocab))]
+            # load vocab
+            for k, v in vocab.items():
+                vocab_list[int(v)] = k
+            # load merges
+            with open(self.merge_txt, 'rt') as merge:
+                for line in merge.readlines():
+                    merge_list.append(line)
+            # write to tokenizer.txt
+            with open(file_path, "w", encoding="utf8") as fp:
+                fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
+                for v in vocab_list:
+                    fp.write(v + '\n')
+                for m in merge_list:
+                    fp.write(m)
         else:
             # huggingface tokenizer
             def unicode_to_byte(u: int):
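For reference, a minimal sketch of reading back the `tokenizer.txt` layout produced by the merges.txt branch above (the function name is illustrative): the first line carries the vocab and merge counts, followed by one vocab token per line and then the raw merge rules.

    def load_tokenizer_txt(path):
        # line 0: "<vocab_size> <merge_count>", then vocab tokens, then merges
        with open(path, encoding='utf8') as f:
            vocab_size, merge_count = map(int, f.readline().split())
            vocab = [f.readline().rstrip('\n') for _ in range(vocab_size)]
            merges = [f.readline().rstrip('\n') for _ in range(merge_count)]
        return vocab, merges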
@@ -706,6 +765,107 @@ def visual_embed(self, input_ids):
             hidden_states[i][a + 1 : b] = images[idx]
         return hidden_states.view(-1, 1, self.hidden_size)
 
+class QWEN2Block(torch.nn.Module):
+    def __init__(self, name, block, block_id, hidden_size, final_layernorm=None):
+        super().__init__()
+        self.name = name
+        self.block = block
+        self.block_id = block_id
+        self.final_layernorm = final_layernorm
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1)
+        rotary_pos_emb = rotary_pos_emb.unsqueeze(0).unsqueeze(0)
+        rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)])
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states=hidden_states,
+                                             attention_mask=attention_mask,
+                                             past_key_value=past_kv,
+                                             rotary_pos_emb=rotary_pos_emb,
+                                             use_cache=True)
+        if self.final_layernorm is not None:
+            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+        if isinstance(presents, tuple):
+            presents = torch.stack(presents)
+        return hidden_states, presents
+
+class Qwen2_Chat(LLM):
+    def __init__(self, args):
+        super().__init__(args)
+
+    def load_model(self, model_path: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+        # Qwen2 models
+        self.model_name = 'Qwen2-7B'
+        transformer = model.model
+        self.lm_ = model.lm_head
+        self.embed_ = transformer.embed_tokens
+        self.blocks_ = transformer.layers
+        self.final_layernorm_ = transformer.norm
+        # some wrappers
+        self.stop_id = self.tokenizer.eos_token_id
+        if hasattr(model, 'generation_config'):
+            self.stop_ids.append(self.stop_id)
+            for id in model.generation_config.eos_token_id:
+                self.stop_ids.append(id)
+        self.block_nums = len(self.blocks_)
+        self.hidden_size = self.embed_.weight.shape[-1]
+        self.embed = Embedding(self.embed_, self.embed_bf16)
+        self.lm = Lm(self.lm_)
+        self.blocks = [QWEN2Block(self.model_name, self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+        # past kv shape for the 4B model: 20 kv heads, head_dim 128
+        self.past_kv_shape = [self.block_nums, 2, 1, 20, 0, 128]
+        # some config for export
+        self.block_dynamic_axes = {
+            "inputs_embeds": {0: "seq_len"},
+            "attention_mask": {2: "seq_len", 3: "seq_len"},
+            "position_ids": {0: "seq_len"},
+            "past_key_values": {2: "history_len"}
+        }
+        self.model_dynamic_axes = {
+            "input_ids": {0: "seq_len"},
+            "attention_mask": {2: "seq_len", 3: "seq_len"},
+            "position_ids": {0: "seq_len"},
+            "past_key_values": {3: "history_len"}
+        }
+
+    def build_prompt(self, query):
+        return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
+
+    def get_attention_mask(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
+        return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
+
+    def get_position_ids(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+
+    def visual_embed(self, input_ids):
+        if not torch.any(input_ids == self.image_start_id):
+            return self.embed(input_ids)
+        bos_pos = torch.where(input_ids == self.image_start_id)
+        eos_pos = torch.where(input_ids == self.image_start_id + 1)
+        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
+        images = []
+        for i, a, b in img_pos:
+            image = input_ids[i][a + 1 : b - 1].tolist()
+            image = image[: image.index(self.image_start_id + 2)]
+            images.append(bytes(image).decode('utf-8'))
+        images = self.visual.encode(images)
+        hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size)
+        for idx, (i, a, b) in enumerate(img_pos):
+            hidden_states[i][a + 1 : b] = images[idx]
+        return hidden_states.view(-1, 1, self.hidden_size)
+
 # llama2
 class LLAMA2Block(torch.nn.Module):
     def __init__(self, block, block_id, hidden_size, final_layernorm=None):
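The rotary table built at the top of `QWEN2Block.forward` above can be reproduced standalone. A minimal sketch (head_dim 128 and base 10000.0 match the hardcoded values above; the function name is illustrative): inverse frequencies for the even channels, per-position angles, duplicated across both halves of the head dimension, then stacked as (cos, sin).

    import torch

    def rotary_table(position_ids, head_dim=128, base=10000.0):
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
        idx_theta = position_ids.float().reshape(-1, 1) * theta   # (seq, head_dim/2)
        angles = torch.cat((idx_theta, idx_theta), dim=-1)        # (seq, head_dim)
        angles = angles.unsqueeze(0).unsqueeze(0)                 # (1, 1, seq, head_dim)
        return torch.stack([torch.cos(angles), torch.sin(angles)])

    print(rotary_table(torch.arange(8)).shape)  # torch.Size([2, 1, 1, 8, 128])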
@@ -1021,6 +1181,7 @@ def export(self):
     'Qwen-7B-Chat': Qwen_Chat,
     'Qwen-1_8B-Chat': Qwen_Chat,
     'Qwen-VL-Chat': Qwen_Chat,
+    'Qwen1_5-4B-Chat': Qwen2_Chat,
     'Baichuan2-7B-Chat': Llama2_7b_Chat,
     'Llama-2-7b-chat-ms': Llama2_7b_Chat,
     'internlm-chat-7b': Llama2_7b_Chat,
@@ -1059,6 +1220,7 @@ def export(self):
     parser.add_argument('--export_lm', action='store_true', help='export llm lm_head to an `onnx` model.')
     parser.add_argument('--export_block', type=int, help='export llm block [id] to an `onnx` model.')
     parser.add_argument('--export_blocks', action='store_true', help='export all llm blocks to `onnx` models.')
+    parser.add_argument('--embed_bin', action='store_true', help='export embedding weights as a binary file with dtype `bfloat16`.')
     parser.add_argument('--embed_bf16', action='store_true', help='use `bfloat16` instead of `float32` for the embedding.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
 
@@ -1103,4 +1265,4 @@ def export(self):
         llm_exporter.export_blocks()
 
     if args.export_block is not None:
-        llm_exporter.export_block(args.export_block)
+        llm_exporter.export_block(args.export_block)