Skip to content

Commit a1362df

Browse files
committed
support export mnn model.
1 parent 29fd8dd commit a1362df

File tree

3 files changed

+58
-12
lines changed

3 files changed

+58
-12
lines changed

README.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ git clone https://modelscope.cn/ZhipuAI/chatglm2-6b.git
4848
3. 执行LLMExporter导出模型
4949
```sh
5050
cd LLMExporter
51-
python llm_export.py --path ../chatglm2-6b --export_path ./onnx --export
51+
python llm_export.py --path ../chatglm2-6b --onnx_path ./onnx --export_split --export_mnn --export_token
5252
```
5353

5454
## 功能
@@ -61,11 +61,13 @@ python llm_export.py --path ../chatglm2-6b --export_path ./onnx --export
6161
- 支持对模型进行对话测试,使用`--test $query`会返回llm的回复内容
6262
- 支持在导出onnx模型后使用onnxruntime对结果一致性进行校验,使用`--export_test`
6363
- 支持将tokenizer导出为文本文件,使用`--export_token`
64+
- 支持将导出的onnx模型转换为mnn模型,默认转换为非对称4bit量化,使用`--export_mnn`
65+
- 指定导出路径使用`--onnx_path``--mnn_path`
6466

6567
## 参数
6668
```
67-
usage: llm_export.py [-h] --path PATH [--type {chatglm-6b,chatglm2-6b,codegeex2-6b,Qwen-7B-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms}]
68-
[--export_path EXPORT_PATH] [--export_verbose] [--export_test] [--test TEST] [--export] [--export_split] [--export_token]
69+
usage: llm_export.py [-h] --path PATH [--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms}]
70+
[--onnx_path ONNX_PATH] [--mnn_path MNN_PATH] [--export_mnn] [--export_verbose] [--export_test] [--test TEST] [--export] [--export_split] [--export_token]
6971
[--export_embed] [--export_lm] [--export_block EXPORT_BLOCK] [--export_blocks] [--embed_bf16]
7072
7173
LLMExporter
@@ -76,11 +78,13 @@ optional arguments:
7678
Can be either:
7779
- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]
7880
- A path to a *directory* clone from repo like `../chatglm-6b`.
79-
--type {chatglm-6b,chatglm2-6b,codegeex2-6b,Qwen-7B-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms}
81+
--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms}
8082
type(`str`, *optional*):
8183
The pretrain llm model type.
82-
--export_path EXPORT_PATH
84+
--onnx_path ONNX_PATH
8385
export onnx model path, default is `./onnx`.
86+
--mnn_path MNN_PATH export mnn model path, default is `./mnn`.
87+
--export_mnn Whether or not to export mnn model after onnx.
8488
--export_verbose Whether or not to export onnx with verbose.
8589
--export_test Whether or not to export onnx with test using onnxruntime.
8690
--test TEST test model inference with query `TEST`.

llm_export.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,33 @@
66
import torch
77
import numpy as np
88
import onnxruntime as ort
9+
import _tools as MNNTools
910
import sentencepiece as spm
1011
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
1112

13+
def onnx2mnn(onnx_path, mnn_dir, quant_bit = 4, asymmetric = True, external_data = False):
    """Convert an exported ONNX model file to an MNN model using MNNTools.

    Args:
        onnx_path: path to the source ``.onnx`` file. Files whose extension is
            not ``.onnx`` (any case) are skipped silently.
        mnn_dir: directory that receives the converted ``<model>.mnn`` file.
        quant_bit: weight quantization bit width forwarded as
            ``--weightQuantBits`` (default 4).
        asymmetric: if True, pass ``--weightQuantAsymmetric`` to use
            asymmetric weight quantization.
        external_data: if True, pass ``--saveExternalData`` (used by callers
            for single-file models larger than 2GB).

    Returns:
        None. The conversion result is written to ``mnn_dir``.
    """
    model_name, model_extension = os.path.splitext(os.path.basename(onnx_path))
    # Case-insensitive check so `.ONNX` / `.Onnx` exports are also accepted.
    if model_extension.lower() != '.onnx':
        return
    mnn_path = os.path.join(mnn_dir, model_name + '.mnn')
    # argv-style argument list for mnnconvert; the leading '' mimics argv[0].
    convert_args = [
        '',
        '-f', 'ONNX',
        '--modelFile', str(onnx_path),
        '--MNNModel', str(mnn_path),
        '--weightQuantBits', str(quant_bit),
    ]
    if asymmetric:
        convert_args.append("--weightQuantAsymmetric")
    if external_data:
        convert_args.append("--saveExternalData")
    MNNTools.mnnconvert(convert_args)
35+
1236
# some wrapper class for export
1337
class Embedding(torch.nn.Module):
1438
def __init__(self, embed, using_bf16: bool = False):
@@ -44,7 +68,13 @@ class LLM(torch.nn.Module):
4468

4569
def __init__(self, args):
4670
super().__init__()
47-
self.export_path = args.export_path
71+
self.onnx_path = args.onnx_path
72+
self.mnn_path = args.mnn_path
73+
if not os.path.exists(self.onnx_path):
74+
os.makedirs(self.onnx_path)
75+
if not os.path.exists(self.mnn_path):
76+
os.makedirs(self.mnn_path)
77+
self.export_mnn = args.export_mnn
4878
self.export_verbose = args.export_verbose
4979
self.export_test = args.export_test
5080
self.embed_bf16 = args.embed_bf16
@@ -134,7 +164,7 @@ def assert_equal(self, torch_outs, onnx_outs):
134164
def export_lm(self):
135165
model = self.lm
136166
hidden_states = torch.randn(1, self.hidden_size)
137-
onnx_model = f'./{self.export_path}/lm.onnx'
167+
onnx_model = f'./{self.onnx_path}/lm.onnx'
138168
torch.onnx.export(model, (hidden_states),
139169
onnx_model,
140170
verbose=self.export_verbose,
@@ -151,11 +181,13 @@ def export_lm(self):
151181
}
152182
onnx_outs = ort_session.run(None, inputs)
153183
self.assert_equal(original_outs, onnx_outs)
184+
if self.export_mnn:
185+
onnx2mnn(onnx_model, self.mnn_path)
154186

155187
def export_embed(self):
156188
model = self.embed
157189
input_ids = torch.arange(3, dtype=torch.long)
158-
onnx_model = f'./{self.export_path}/embedding.onnx'
190+
onnx_model = f'./{self.onnx_path}/embedding.onnx'
159191
torch.onnx.export(model, (input_ids),
160192
onnx_model,
161193
verbose=self.export_verbose,
@@ -175,6 +207,8 @@ def export_embed(self):
175207
}
176208
onnx_outs = ort_session.run(None, inputs)
177209
self.assert_equal(original_outs, onnx_outs)
210+
if self.export_mnn:
211+
onnx2mnn(onnx_model, self.mnn_path)
178212

179213
def export_block(self, block_id: int):
180214
self.seq_len = 3
@@ -184,7 +218,7 @@ def export_block(self, block_id: int):
184218
position_ids = self.get_position_ids()
185219
past_key_values = torch.zeros(self.past_kv_shape[1:])
186220
model = self.blocks[block_id]
187-
onnx_model = f'./{self.export_path}/block_{block_id}.onnx'
221+
onnx_model = f'./{self.onnx_path}/block_{block_id}.onnx'
188222
torch.onnx.export(
189223
model, (inputs_embeds, attention_mask, position_ids, past_key_values),
190224
onnx_model,
@@ -207,6 +241,8 @@ def export_block(self, block_id: int):
207241
}
208242
onnx_outs = ort_session.run(None, inputs)
209243
self.assert_equal(original_outs, onnx_outs)
244+
if self.export_mnn:
245+
onnx2mnn(onnx_model, self.mnn_path)
210246

211247
def export_blocks(self):
212248
for i in range(self.block_nums):
@@ -220,7 +256,7 @@ def export(self):
220256
attention_mask = self.get_attention_mask()
221257
position_ids = self.get_position_ids()
222258
past_key_values = torch.zeros(self.past_kv_shape)
223-
onnx_model = f'./{self.export_path}/llm.onnx'
259+
onnx_model = f'./{self.onnx_path}/llm.onnx'
224260
torch.onnx.export(
225261
model, (input_ids, attention_mask, position_ids, past_key_values),
226262
onnx_model,
@@ -244,9 +280,12 @@ def export(self):
244280
}
245281
onnx_outs = ort_session.run(None, inputs)
246282
self.assert_equal(original_outs, onnx_outs)
283+
if self.export_mnn:
284+
# single model is > 2G, using external_data
285+
onnx2mnn(onnx_model, self.mnn_path, 4, True, True)
247286

248287
def export_tokenizer(self):
249-
file_path = os.path.join(self.export_path, "tokenizer.txt")
288+
file_path = os.path.join(self.onnx_path, "tokenizer.txt")
250289
if self.sp_model is not None:
251290
# senetencepiece
252291
NORMAL = 1; UNKNOWN = 2; CONTROL = 3
@@ -644,7 +683,9 @@ def get_position_ids(self) -> torch.Tensor:
644683
help='type(`str`, *optional*):'
645684
'\n\tThe pretrain llm model type.'
646685
)
647-
parser.add_argument('--export_path', type=str, default='./onnx', help='export onnx model path, defaut is `./onnx`.')
686+
parser.add_argument('--onnx_path', type=str, default='./onnx', help='export onnx model path, defaut is `./onnx`.')
687+
parser.add_argument('--mnn_path', type=str, default='./mnn', help='export mnn model path, defaut is `./mnn`.')
688+
parser.add_argument('--export_mnn', action='store_true', default=False, help='Whether or not to export mnn model after onnx.')
648689
parser.add_argument('--export_verbose', action='store_true', default=False, help='Whether or not to export onnx with verbose.')
649690
parser.add_argument('--export_test', action='store_true', help='Whether or not to export onnx with test using onnxruntime.')
650691
parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
MNN==2.8.0
12
numpy==1.25.2
23
onnxruntime==1.15.1
34
torch==2.0.1

0 commit comments

Comments
 (0)