@@ -5,6 +5,7 @@
 import argparse
 import torch
 import numpy as np
+from onnxslim import slim
 import onnxruntime as ort
 import _tools as MNNTools
 import sentencepiece as spm
@@ -83,7 +84,7 @@ def __init__(self, args):
         self.export_verbose = args.export_verbose
         self.export_test = args.export_test
         self.embed_bf16 = args.embed_bf16
-        self.slim = args.slim
+        self.skip_slim = args.skip_slim
         tokenizer_model = os.path.join(args.path, 'tokenizer.model')
         if os.path.exists(tokenizer_model):
             self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
@@ -186,8 +187,7 @@ def export_lm(self):
                         output_names=['token_id'],
                         do_constant_folding=True,
                         opset_version=15)
-        if self.slim:
-            from onnxslim import slim
+        if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         # test lm
         if self.export_test:
@@ -217,8 +217,7 @@ def export_visual(self):
                         }},
                         do_constant_folding=True,
                         opset_version=15)
-        if self.slim:
-            from onnxslim import slim
+        if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         # test
         if self.export_test:
@@ -246,8 +245,7 @@ def export_embed(self):
                         }},
                         do_constant_folding=True,
                         opset_version=15)
-        if self.slim:
-            from onnxslim import slim
+        if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         # test
         if self.export_test:
@@ -281,8 +279,7 @@ def export_block(self, block_id: int):
                         dynamic_axes=self.block_dynamic_axes,
                         do_constant_folding=True,
                         opset_version=15)
-        if self.slim:
-            from onnxslim import slim
+        if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         if self.export_test:
             original_outs = model(inputs_embeds, attention_mask, position_ids, past_key_values)
@@ -322,8 +319,7 @@ def export(self):
                         dynamic_axes=self.model_dynamic_axes,
                         do_constant_folding=True,
                         opset_version=15)
-        if self.slim:
-            from onnxslim import slim
+        if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         if self.export_test:
             # test
@@ -961,8 +957,7 @@ def export(self):
                         dynamic_axes=self.model_dynamic_axes,
                         do_constant_folding=True,
                         opset_version=15)
-        if self.slim:
-            from onnxslim import slim
+        if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         if self.export_test:
             self.seq_len = 4
@@ -1042,7 +1037,7 @@ def get_attention_mask(self) -> torch.Tensor:
     parser.add_argument('--export_block', type=int, help='export llm block [id] to an `onnx` model.')
     parser.add_argument('--export_blocks', action='store_true', help='export llm all blocks to `onnx` models.')
     parser.add_argument('--embed_bf16', action='store_true', help='using `bfloat16` replace `float32` in embedding.')
-    parser.add_argument('--slim', action='store_true', help='Whether or not to slim the exported onnx model.')
+    parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')


     args = parser.parse_args()
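
Taken together, these hunks flip the flag's polarity: onnxslim moves from a lazy, opt-in import inside each `if self.slim:` branch to a single top-level import, and slimming now runs by default unless `--skip_slim` is passed. Below is a minimal Python sketch of the new control flow, assuming the in-place `slim(model, output_model=...)` usage shown in the diff; the model path is a hypothetical stand-in for the exporter's own output paths.

import argparse
from onnxslim import slim  # top-level import: onnxslim is now a hard dependency

parser = argparse.ArgumentParser()
# Opt-out flag: slimming is the default; pass --skip_slim to disable it.
parser.add_argument('--skip_slim', action='store_true',
                    help='Whether or not to skip onnx-slim.')
args = parser.parse_args()

onnx_model = 'model.onnx'  # hypothetical path for illustration
if not args.skip_slim:
    # Slim in place: output_model points back at the exported file,
    # mirroring the slim(onnx_model, output_model=onnx_model) calls above.
    slim(onnx_model, output_model=onnx_model)

Note the compatibility cost of the default flip: existing invocations that passed `--slim` will now fail argument parsing, and callers that want unslimmed output must switch to `--skip_slim`.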