Commit 34aa0a8

support qwen-1.5-4b and embed_bin.
1 parent 185fcab commit 34aa0a8
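
In short, this commit adds a Qwen2_Chat exporter (with a QWEN2Block wrapper) and maps 'Qwen1_5-4B-Chat' to it, and introduces an `--embed_bin` flag that writes the embedding weights to embeddings_bf16.bin and exports llm.onnx without the embedding layer. Going by the comment added in `__init__` below, the flag is meant to be combined with `--export`, e.g. `python llm_export ../path --export --embed_bin`, where the path points at the model directory.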

4 files changed: +1756 −7 lines changed

llm_export.py

Lines changed: 169 additions & 7 deletions
@@ -7,9 +7,12 @@
 import numpy as np
 from onnxslim import slim
 import onnxruntime as ort
-import _tools as MNNTools
 import sentencepiece as spm
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+try:
+    import _tools as MNNTools
+except:
+    MNNTools = None
 
 def onnx2mnn(onnx_path, mnn_dir, quant_bit = 4, asymmetric = True, external_data = False, bizCode : str= None):
     model_name, model_extension = os.path.splitext(os.path.basename(onnx_path))
@@ -83,13 +86,25 @@ def __init__(self, args):
         self.export_mnn = args.export_mnn
         self.export_verbose = args.export_verbose
         self.export_test = args.export_test
-        self.embed_bf16 = args.embed_bf16
+        # default is False, just set True when using below command:
+        # `python llm_export ../path --export --embed_bin` to export single model without embedding
+        self.without_embed = False
+        self.embed_bin = args.embed_bin
+        if self.embed_bin:
+            self.embed_bf16 = True
+        else:
+            self.embed_bf16 = args.embed_bf16
         self.skip_slim = args.skip_slim
         tokenizer_model = os.path.join(args.path, 'tokenizer.model')
         if os.path.exists(tokenizer_model):
             self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
         else:
             self.sp_model = None
+        merge_file = os.path.join(args.path, 'merges.txt')
+        if os.path.exists(merge_file):
+            self.merge_txt = merge_file
+        else:
+            self.merge_txt = None
         self.stop_ids = []
         self.max_length = 1024
         self.hidden_size = 4096
@@ -111,11 +126,14 @@ def export_vocab(self):
     def visual_embed(self, input_ids):
         raise NotImplementedError
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values):
-        if self.visual is not None and past_key_values[0] is None:
-            hidden_states = self.visual_embed(input_ids)
+    def __embedding(self, input_ids):
+        if self.visual is not None and self.token_len == 0:
+            input_embeds = self.visual_embed(input_ids)
         else:
-            hidden_states = self.embed(input_ids)
+            input_embeds = self.embed(input_ids)
+        return input_embeds
+
+    def __decode(self, hidden_states, attention_mask, position_ids, past_key_values):
         presents = []
         for i in range(self.block_nums):
             hidden_states, kv = self.blocks[i](hidden_states, attention_mask, position_ids, past_key_values[i])
@@ -126,6 +144,11 @@ def forward(self, input_ids, attention_mask, position_ids, past_key_values):
         self.token_len += 1
         return token_id, presents
 
+    def forward(self, input_ids, attention_mask, position_ids, past_key_values):
+        if self.without_embed:
+            return self.__decode(input_ids, attention_mask, position_ids, past_key_values)
+        return self.__decode(self.__embedding(input_ids), attention_mask, position_ids, past_key_values)
+
     # some test functions
     def build_prompt(self, query):
         if hasattr(self.tokenizer, 'build_prompt'):
@@ -233,6 +256,14 @@ def export_visual(self):
 
     def export_embed(self):
         model = self.embed
+        if self.embed_bin:
+            import ctypes
+            tensor_data = model.embed.weight.data
+            data_ptr = tensor_data.untyped_storage().data_ptr()
+            buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
+            with open(f'./{self.mnn_path}/embeddings_bf16.bin', 'wb') as f:
+                f.write(buffer)
+            return
         input_ids = torch.arange(3, dtype=torch.long)
         onnx_model = f'./{self.onnx_path}/embedding.onnx'
         torch.onnx.export(model, (input_ids),
@@ -308,6 +339,10 @@ def export(self):
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape)
         onnx_model = f'./{self.onnx_path}/llm.onnx'
+        if self.embed_bin:
+            self.without_embed = True
+            input_ids = self.__embedding(input_ids)
+        print('export start ...')
         torch.onnx.export(
             model, (input_ids, attention_mask, position_ids, past_key_values),
             onnx_model,
@@ -319,6 +354,7 @@ def export(self):
             dynamic_axes=self.model_dynamic_axes,
             do_constant_folding=True,
             opset_version=15)
+        print('export done!')
         if not self.skip_slim:
             slim(onnx_model, output_model=onnx_model)
         if self.export_test:
@@ -336,11 +372,14 @@ def export(self):
         if self.export_mnn:
             # single model is > 2G, using external_data
             onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric, True)
+        if self.without_embed:
+            self.without_embed = False
 
     def export_tokenizer(self):
         file_path = os.path.join(self.onnx_path, "tokenizer.txt")
         if self.sp_model is not None:
             # sentencepiece
+            print('# sentencepiece tokenizer')
             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
             USER_DEFINED = 4; UNUSED = 5; BYTE = 6
             fp = open(file_path, "w", encoding="utf8")
@@ -365,6 +404,7 @@ def export_tokenizer(self):
                 fp.write(f'{token_encode} {score} {type}\n')
             fp.close()
         elif hasattr(self.tokenizer, 'mergeable_ranks'):
+            print('# tiktoken tokenizer')
             # tiktoken
             with open(file_path, "w", encoding="utf8") as fp:
                 for k, v in self.tokenizer.mergeable_ranks.items():
@@ -374,6 +414,25 @@ def export_tokenizer(self):
                     for k, v in self.tokenizer.special_tokens.items():
                         line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
                         fp.write(line)
+        elif self.merge_txt is not None:
+            # huggingface tokenizer
+            merge_list = []
+            vocab = self.tokenizer.get_vocab()
+            vocab_list = ['<unk>' for i in range(len(vocab))]
+            # load vocab
+            for k, v in vocab.items():
+                vocab_list[int(v)] = k
+            # load merge
+            with open(self.merge_txt, 'rt') as merge:
+                for line in merge.readlines():
+                    merge_list.append(line)
+            # write to tokenizer.txt
+            with open(file_path, "w", encoding="utf8") as fp:
+                fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
+                for v in vocab_list:
+                    fp.write(v + '\n')
+                for m in merge_list:
+                    fp.write(m)
         else:
             # huggingface tokenizer
             def unicode_to_byte(u: int):
@@ -706,6 +765,107 @@ def visual_embed(self, input_ids):
             hidden_states[i][a + 1 : b] = images[idx]
         return hidden_states.view(-1, 1, self.hidden_size)
 
+class QWEN2Block(torch.nn.Module):
+    def __init__(self, name, block, block_id, hidden_size, final_layernorm = None):
+        super().__init__()
+        self.name = name
+        self.block = block
+        self.block_id = block_id
+        self.final_layernorm = final_layernorm
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1)
+        rotary_pos_emb = rotary_pos_emb.unsqueeze(0).unsqueeze(0)
+        rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)])
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states=hidden_states,
+                                             attention_mask=attention_mask,
+                                             past_key_value=past_kv,
+                                             rotary_pos_emb=rotary_pos_emb,
+                                             use_cache=True)
+        if self.final_layernorm is not None:
+            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+        if isinstance(presents, tuple):
+            presents = torch.stack(presents)
+        return hidden_states, presents
+
+class Qwen2_Chat(LLM):
+    def __init__(self, args):
+        super().__init__(args)
+
+    def load_model(self, model_path: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+        # Qwen2 models
+        self.model_name = 'Qwen2-7B'
+        transformer = model.model
+        self.lm_ = model.lm_head
+        self.embed_ = transformer.embed_tokens
+        self.blocks_ = transformer.layers
+        self.final_layernorm_ = transformer.norm
+        # some wrapper
+        self.stop_id = self.tokenizer.eos_token_id
+        if hasattr(model, 'generation_config'):
+            self.stop_ids.append(self.stop_id)
+            for id in model.generation_config.eos_token_id:
+                self.stop_ids.append(id)
+        self.block_nums = len(self.blocks_)
+        self.hidden_size = self.embed_.weight.shape[-1]
+        self.embed = Embedding(self.embed_, self.embed_bf16)
+        self.lm = Lm(self.lm_)
+        self.blocks = [QWEN2Block(self.model_name, self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+        # 4b
+        self.past_kv_shape = [self.block_nums, 2, 1, 20, 0, 128]
+        # some config for export
+        self.block_dynamic_axes = {
+            "inputs_embeds" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 2: "history_len" }
+        }
+        self.model_dynamic_axes = {
+            "input_ids" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 3: "history_len" }
+        }
+
+    def build_prompt(self, query):
+        return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
+
+    def get_attention_mask(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
+        return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
+
+    def get_position_ids(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+
+    def visual_embed(self, input_ids):
+        if not torch.any(input_ids == self.image_start_id):
+            return self.embed(input_ids)
+        bos_pos = torch.where(input_ids == self.image_start_id)
+        eos_pos = torch.where(input_ids == self.image_start_id + 1)
+        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
+        images = []
+        for i, a, b in img_pos:
+            image = input_ids[i][a + 1 : b - 1].tolist()
+            image = image[ : image.index(self.image_start_id + 2)]
+            images.append(bytes(image).decode('utf-8'))
+        images = self.visual.encode(images)
+        hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size)
+        for idx, (i, a, b) in enumerate(img_pos):
+            hidden_states[i][a + 1 : b] = images[idx]
+        return hidden_states.view(-1, 1, self.hidden_size)
+
 # llama2
 class LLAMA2Block(torch.nn.Module):
     def __init__(self, block, block_id, hidden_size, final_layernorm = None):
@@ -1021,6 +1181,7 @@ def export(self):
     'Qwen-7B-Chat': Qwen_Chat,
     'Qwen-1_8B-Chat': Qwen_Chat,
     'Qwen-VL-Chat': Qwen_Chat,
+    'Qwen1_5-4B-Chat': Qwen2_Chat,
     'Baichuan2-7B-Chat': Llama2_7b_Chat,
     'Llama-2-7b-chat-ms': Llama2_7b_Chat,
     'internlm-chat-7b': Llama2_7b_Chat,
@@ -1059,6 +1220,7 @@ def export(self):
     parser.add_argument('--export_lm', action='store_true', help='export llm lm_head to an `onnx` model.')
     parser.add_argument('--export_block', type=int, help='export llm block [id] to an `onnx` model.')
     parser.add_argument('--export_blocks', action='store_true', help='export llm all blocks to `onnx` models.')
+    parser.add_argument('--embed_bin', action='store_true', help='export embedding weight as bin file with dtype `bfloat16`')
     parser.add_argument('--embed_bf16', action='store_true', help='using `bfloat16` replace `float32` in embedding.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
 
@@ -1103,4 +1265,4 @@ def export(self):
         llm_exporter.export_blocks()
 
     if args.export_block is not None:
-        llm_exporter.export_block(args.export_block)
+        llm_exporter.export_block(args.export_block)
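
For reference, a minimal sketch (not part of the commit) of how the embeddings_bf16.bin written by export_embed() above could be loaded back as float32. export_embed() dumps the raw bfloat16 weight bytes (2 bytes per element); the vocab_size and hidden_size values are taken from the Qwen1.5-4B config below and are only illustrative assumptions, and a little-endian host is assumed:

    import numpy as np

    vocab_size, hidden_size = 151936, 2560  # assumed from the config below
    # read the raw bfloat16 bits as unsigned 16-bit integers
    bits = np.fromfile('embeddings_bf16.bin', dtype='<u2').reshape(vocab_size, hidden_size)
    # bfloat16 keeps the high 16 bits of an IEEE float32, so widen by shifting left 16 bits
    weights = (bits.astype(np.uint32) << 16).view(np.float32)
    print(weights.shape, weights.dtype)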
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+{
+  "architectures": [
+    "Qwen2ForCausalLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_qwen2.Qwen2Config",
+    "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM"
+  },
+  "attention_dropout": 0.0,
+  "bos_token_id": 151643,
+  "eos_token_id": 151645,
+  "hidden_act": "silu",
+  "hidden_size": 2560,
+  "initializer_range": 0.02,
+  "intermediate_size": 6912,
+  "max_position_embeddings": 32768,
+  "max_window_layers": 21,
+  "model_type": "qwen2",
+  "num_attention_heads": 20,
+  "num_hidden_layers": 40,
+  "num_key_value_heads": 20,
+  "rms_norm_eps": 1e-06,
+  "rope_theta": 5000000.0,
+  "sliding_window": 32768,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.37.0",
+  "use_cache": true,
+  "use_sliding_window": false,
+  "vocab_size": 151936
+}
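
This config is consistent with the shapes hard-coded in Qwen2_Chat above: num_hidden_layers = 40 blocks, num_key_value_heads = 20, and head_dim = hidden_size / num_attention_heads = 2560 / 20 = 128, matching past_kv_shape = [block_nums, 2, 1, 20, 0, 128].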
