 import torch
 import numpy as np
 import onnxruntime as ort
+import _tools as MNNTools
 import sentencepiece as spm
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 
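+# Convert an exported ONNX model to MNN via MNNTools.mnnconvert, with optional
+# weight quantization (quant_bit / asymmetric) and external weight storage.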
+def onnx2mnn(onnx_path, mnn_dir, quant_bit = 4, asymmetric = True, external_data = False):
+    model_name, model_extension = os.path.splitext(os.path.basename(onnx_path))
+    if model_extension != '.onnx':
+        return
+    mnn_name = model_name + '.mnn'
+    mnn_path = os.path.join(mnn_dir, mnn_name)
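+    # Build an argv-style argument list for MNNTools.mnnconvert; the leading
+    # empty string appears to stand in for argv[0].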
+    convert_args = [
+        '',
+        '-f',
+        'ONNX',
+        '--modelFile',
+        str(onnx_path),
+        '--MNNModel',
+        str(mnn_path),
+        '--weightQuantBits',
+        str(quant_bit)
+    ]
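+    # Optional converter flags: asymmetric weight quantization, and saving
+    # weights as external data (used for the single-file export larger than 2GB).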
+    if asymmetric:
+        convert_args.append("--weightQuantAsymmetric")
+    if external_data:
+        convert_args.append("--saveExternalData")
+    MNNTools.mnnconvert(convert_args)
+
 # some wrapper class for export
 class Embedding(torch.nn.Module):
     def __init__(self, embed, using_bf16: bool = False):
@@ -44,7 +68,13 @@ class LLM(torch.nn.Module):
 
     def __init__(self, args):
         super().__init__()
-        self.export_path = args.export_path
+        self.onnx_path = args.onnx_path
+        self.mnn_path = args.mnn_path
+        if not os.path.exists(self.onnx_path):
+            os.makedirs(self.onnx_path)
+        if not os.path.exists(self.mnn_path):
+            os.makedirs(self.mnn_path)
+        self.export_mnn = args.export_mnn
         self.export_verbose = args.export_verbose
         self.export_test = args.export_test
         self.embed_bf16 = args.embed_bf16
@@ -134,7 +164,7 @@ def assert_equal(self, torch_outs, onnx_outs):
     def export_lm(self):
         model = self.lm
         hidden_states = torch.randn(1, self.hidden_size)
-        onnx_model = f'./{self.export_path}/lm.onnx'
+        onnx_model = f'./{self.onnx_path}/lm.onnx'
         torch.onnx.export(model, (hidden_states),
                           onnx_model,
                           verbose=self.export_verbose,
@@ -151,11 +181,13 @@ def export_lm(self):
             }
             onnx_outs = ort_session.run(None, inputs)
             self.assert_equal(original_outs, onnx_outs)
+        if self.export_mnn:
+            onnx2mnn(onnx_model, self.mnn_path)
 
     def export_embed(self):
         model = self.embed
         input_ids = torch.arange(3, dtype=torch.long)
-        onnx_model = f'./{self.export_path}/embedding.onnx'
+        onnx_model = f'./{self.onnx_path}/embedding.onnx'
         torch.onnx.export(model, (input_ids),
                           onnx_model,
                           verbose=self.export_verbose,
@@ -175,6 +207,8 @@ def export_embed(self):
             }
             onnx_outs = ort_session.run(None, inputs)
             self.assert_equal(original_outs, onnx_outs)
+        if self.export_mnn:
+            onnx2mnn(onnx_model, self.mnn_path)
 
     def export_block(self, block_id: int):
         self.seq_len = 3
@@ -184,7 +218,7 @@ def export_block(self, block_id: int):
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape[1:])
         model = self.blocks[block_id]
-        onnx_model = f'./{self.export_path}/block_{block_id}.onnx'
+        onnx_model = f'./{self.onnx_path}/block_{block_id}.onnx'
         torch.onnx.export(
             model, (inputs_embeds, attention_mask, position_ids, past_key_values),
             onnx_model,
@@ -207,6 +241,8 @@ def export_block(self, block_id: int):
             }
             onnx_outs = ort_session.run(None, inputs)
             self.assert_equal(original_outs, onnx_outs)
+        if self.export_mnn:
+            onnx2mnn(onnx_model, self.mnn_path)
 
     def export_blocks(self):
         for i in range(self.block_nums):
@@ -220,7 +256,7 @@ def export(self):
         attention_mask = self.get_attention_mask()
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape)
-        onnx_model = f'./{self.export_path}/llm.onnx'
+        onnx_model = f'./{self.onnx_path}/llm.onnx'
         torch.onnx.export(
             model, (input_ids, attention_mask, position_ids, past_key_values),
             onnx_model,
@@ -244,9 +280,12 @@ def export(self):
             }
             onnx_outs = ort_session.run(None, inputs)
             self.assert_equal(original_outs, onnx_outs)
+        if self.export_mnn:
+            # single model is > 2G, using external_data
+            onnx2mnn(onnx_model, self.mnn_path, 4, True, True)
 
     def export_tokenizer(self):
-        file_path = os.path.join(self.export_path, "tokenizer.txt")
+        file_path = os.path.join(self.onnx_path, "tokenizer.txt")
         if self.sp_model is not None:
             # sentencepiece
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
@@ -644,7 +683,9 @@ def get_position_ids(self) -> torch.Tensor:
                         help='type(`str`, *optional*):'
                         '\n\tThe pretrain llm model type.'
     )
-    parser.add_argument('--export_path', type=str, default='./onnx', help='export onnx model path, default is `./onnx`.')
+    parser.add_argument('--onnx_path', type=str, default='./onnx', help='export onnx model path, default is `./onnx`.')
+    parser.add_argument('--mnn_path', type=str, default='./mnn', help='export mnn model path, default is `./mnn`.')
+    parser.add_argument('--export_mnn', action='store_true', default=False, help='Whether or not to export the mnn model after onnx.')
     parser.add_argument('--export_verbose', action='store_true', default=False, help='Whether or not to export onnx with verbose.')
     parser.add_argument('--export_test', action='store_true', help='Whether or not to export onnx with test using onnxruntime.')
     parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')