Commit 29fd8dd

support qwen-1.8b export
1 parent f331381 commit 29fd8dd

3 files changed: +1458 -11 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@ llm-export is an LLM export tool that can export LLM models to ONNX models
 - [![Download][download-qwen-7b-chat-onnx]][release-qwen-7b-chat-onnx]
 - [![Download][download-baichuan2-7b-chat-onnx]][release-baichuan2-7b-chat-onnx]
 - [![Download][download-llama2-7b-chat-onnx]][release-llama2-7b-chat-onnx]
+- [![Download][download-qwen-1.8b-chat-onnx]][release-qwen-1.8b-chat-onnx]

 [download-chatglm-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm-6b-onnx/total
 [download-chatglm2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm2-6b-onnx/total
@@ -23,13 +24,15 @@ llm-export is an LLM export tool that can export LLM models to ONNX models
 [download-qwen-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-7b-chat-onnx/total
 [download-baichuan2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/baichuan2-7b-chat-onnx/total
 [download-llama2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/llama2-7b-chat-onnx/total
+[download-qwen-1.8b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-1.8b-onnx/total
 [release-chatglm-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm-6b-onnx
 [release-chatglm2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm2-6b-onnx
 [release-chatglm3-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm3-6b-onnx
 [release-codegeex2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/codegeex2-6b-onnx
 [release-qwen-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-7b-chat-onnx
 [release-baichuan2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/baichuan2-7b-chat-onnx
 [release-llama2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/llama2-7b-chat-onnx
+[release-qwen-1.8b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-1.8b-onnx

 ## Usage
 1. Clone this project locally

llm_export.py

Lines changed: 49 additions & 11 deletions
@@ -14,6 +14,7 @@ class Embedding(torch.nn.Module):
     def __init__(self, embed, using_bf16: bool = False):
         super().__init__()
         self.bf16 = using_bf16
+        self.embed_dim = embed.weight.shape[-1]
         if using_bf16:
             # using bf16 embedding weight
             self.embed = embed.bfloat16()
@@ -24,7 +25,7 @@ def forward(self, input_ids):
         res = self.embed(input_ids)
         if self.bf16:
             res = res.float()
-        return res.view(-1, 1, 4096)
+        return res.view(-1, 1, self.embed_dim)

 class Lm(torch.nn.Module):
     def __init__(self, lm):
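
Note on the two hunks above: the embedding dimension is now read from the weight itself instead of being hard-coded to 4096. A minimal standalone sketch of that behaviour with toy embeddings; the 2048 value for Qwen-1.8B is an assumption consistent with the 16 x 128 KV shape added further down, not something this diff states explicitly:

import torch

# Toy stand-ins for the real model embeddings; only the last weight dim matters here.
for name, hidden in [("qwen-7b", 4096), ("qwen-1.8b", 2048)]:    # 2048 is assumed
    embed = torch.nn.Embedding(num_embeddings=8, embedding_dim=hidden)
    embed_dim = embed.weight.shape[-1]             # as in Embedding.__init__ above
    input_ids = torch.tensor([1, 2, 3])
    res = embed(input_ids).view(-1, 1, embed_dim)  # as in Embedding.forward above
    print(name, embed_dim, tuple(res.shape))       # (3, 1, hidden)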
@@ -52,8 +53,9 @@ def __init__(self, args):
             self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
         else:
             self.sp_model = None
-        self.load_model(args.path)
         self.max_length = 1024
+        self.hidden_size = 4096
+        self.load_model(args.path)

     def load_model(self, model_path: str):
         raise NotImplementedError
@@ -131,7 +133,7 @@ def assert_equal(self, torch_outs, onnx_outs):

     def export_lm(self):
         model = self.lm
-        hidden_states = torch.randn(1, 4096)
+        hidden_states = torch.randn(1, self.hidden_size)
         onnx_model = f'./{self.export_path}/lm.onnx'
         torch.onnx.export(model, (hidden_states),
                           onnx_model,
@@ -177,7 +179,7 @@ def export_embed(self):
     def export_block(self, block_id: int):
         self.seq_len = 3
         self.token_len = 0
-        inputs_embeds = torch.randn((self.seq_len, 1, 4096))
+        inputs_embeds = torch.randn((self.seq_len, 1, self.hidden_size))
         attention_mask = self.get_attention_mask()
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape[1:])
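
The dummy inputs that export_block builds now follow hidden_size and past_kv_shape rather than fixed 4096-sized tensors. A small sketch of the resulting shapes for the two Qwen variants; reading past_kv_shape as [n_blocks, k/v, batch, past_len, n_heads, head_dim] and 2048 as Qwen-1.8B's hidden size are assumptions inferred from the values in this commit:

import torch

configs = {
    "qwen-7b":   {"hidden_size": 4096, "past_kv_shape": [32, 2, 1, 0, 32, 128]},
    "qwen-1.8b": {"hidden_size": 2048, "past_kv_shape": [24, 2, 1, 0, 16, 128]},  # 2048 = 16 * 128 (assumed)
}

seq_len = 3  # export_block traces a 3-token dummy prompt
for name, cfg in configs.items():
    inputs_embeds = torch.randn((seq_len, 1, cfg["hidden_size"]))
    # the per-block KV dummy drops the leading n_blocks dim, i.e. past_kv_shape[1:]
    past_key_values = torch.zeros(cfg["past_kv_shape"][1:])
    print(name, tuple(inputs_embeds.shape), tuple(past_key_values.shape))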
@@ -294,7 +296,7 @@ def forward(self, hidden_states, attention_mask, position_ids, past_kv):
                                              use_cache=True)
         if self.final_layernorm is not None:
             hidden_states = self.final_layernorm(hidden_states)
-            hidden_states = hidden_states.view(-1, 4096)[-1].view(1, 1, 4096)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         if isinstance(presents, tuple):
             presents = torch.stack(presents)
         return hidden_states, presents
@@ -376,7 +378,7 @@ def forward(self, hidden_states, attention_mask, position_ids, past_kv):
                                              rotary_pos_emb=rotary_pos_emb)
         if self.final_layernorm is not None:
             hidden_states = self.final_layernorm(hidden_states)
-            hidden_states = hidden_states.view(-1, 4096)[-1].view(1, 1, 4096)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         if isinstance(presents, tuple):
             presents = torch.stack(presents)
         return hidden_states, presents
@@ -442,27 +444,55 @@ def build_prompt(self, query):

 # qwen
 class QWENBlock(torch.nn.Module):
-    def __init__(self, block, block_id, final_layernorm = None):
+    def __init__(self, block, block_id, hidden_size, final_layernorm = None):
         super().__init__()
         self.block = block
         self.block_id = block_id
         self.final_layernorm = final_layernorm
+        self.hidden_size = hidden_size

     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
         theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
         position_ids = position_ids.float().reshape(-1, 1)
         idx_theta = position_ids * theta
         rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1)
         rotary_pos_emb = rotary_pos_emb.unsqueeze(1).unsqueeze(0)
-        hidden_states = hidden_states.view(1, -1, 4096)
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
         hidden_states, presents = self.block(hidden_states,
                                              past_kv,
                                              attention_mask,
                                              rotary_pos_emb,
                                              use_cache=True)
         if self.final_layernorm is not None:
             hidden_states = self.final_layernorm(hidden_states)
-            hidden_states = hidden_states.view(-1, 4096)[-1].view(1, 1, 4096)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+        if isinstance(presents, tuple):
+            presents = torch.stack(presents)
+        return hidden_states, presents
+
+class QWEN18Block(torch.nn.Module):
+    def __init__(self, block, block_id, hidden_size, final_layernorm = None):
+        super().__init__()
+        self.block = block
+        self.block_id = block_id
+        self.final_layernorm = final_layernorm
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1).unsqueeze(1).unsqueeze(0)
+        rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)])
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states,
+                                             rotary_pos_emb,
+                                             past_kv,
+                                             attention_mask,
+                                             use_cache=True)
+        if self.final_layernorm is not None:
+            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         if isinstance(presents, tuple):
             presents = torch.stack(presents)
         return hidden_states, presents
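
The main functional difference between QWENBlock and the new QWEN18Block is the rotary position embedding handed to the underlying block (and the argument order of the block call). A self-contained sketch of the two constructions as they appear above, showing the shapes for a 3-token prompt:

import torch

def qwen7b_rotary(position_ids: torch.Tensor) -> torch.Tensor:
    # Mirrors QWENBlock.forward above: raw angles, shape (1, seq_len, 1, 128).
    theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
    idx_theta = position_ids.float().reshape(-1, 1) * theta
    rotary = torch.cat((idx_theta, idx_theta), dim=-1)
    return rotary.unsqueeze(1).unsqueeze(0)

def qwen18b_rotary(position_ids: torch.Tensor) -> torch.Tensor:
    # Mirrors QWEN18Block.forward above: the same angles, but with cos/sin applied
    # and stacked into shape (2, 1, seq_len, 1, 128); the wrapper also passes
    # rotary_pos_emb before past_kv and attention_mask in the block call.
    theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
    idx_theta = position_ids.float().reshape(-1, 1) * theta
    rotary = torch.cat((idx_theta, idx_theta), dim=-1).unsqueeze(1).unsqueeze(0)
    return torch.stack([torch.cos(rotary), torch.sin(rotary)])

pos = torch.arange(3)
print(qwen7b_rotary(pos).shape)   # torch.Size([1, 3, 1, 128])
print(qwen18b_rotary(pos).shape)  # torch.Size([2, 1, 3, 1, 128])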
@@ -482,11 +512,18 @@ def load_model(self, model_path: str):
         # some wrapper
         self.stop_id = self.tokenizer.im_end_id
         self.block_nums = len(self.blocks_)
+        self.hidden_size = transformer.embed_dim
         self.embed = Embedding(self.embed_, self.embed_bf16)
         self.lm = Lm(self.lm_)
-        self.blocks = [QWENBlock(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+        if self.block_nums == 32:
+            # qwen-7b
+            self.blocks = [QWENBlock(self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+            self.past_kv_shape = [32, 2, 1, 0, 32, 128]
+        elif self.block_nums == 24:
+            # qwen-1.8b
+            self.blocks = [QWEN18Block(self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+            self.past_kv_shape = [24, 2, 1, 0, 16, 128]
         # some config for export
-        self.past_kv_shape = [32, 2, 1, 0, 32, 128]
         self.block_dynamic_axes = {
             "inputs_embeds" : { 0: "seq_len" },
             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
@@ -594,6 +631,7 @@ def get_position_ids(self) -> torch.Tensor:
     'chatglm3-6b': Chatglm3_6b,
     'codegeex2-6b': Chatglm2_6b,
     'Qwen-7B-Chat': Qwen_7b_Chat,
+    'Qwen-1_8B-Chat': Qwen_7b_Chat,
     'Baichuan2-7B-Chat': Llama2_7b_Chat,
     'Llama-2-7b-chat-ms': Llama2_7b_Chat
 }
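
Finally, the model map routes 'Qwen-1_8B-Chat' through the existing Qwen_7b_Chat exporter, and load_model then picks the block wrapper and KV-cache shape from the number of transformer blocks. A compact sketch of that selection logic; the helper below is illustrative only (it returns wrapper names as strings and raises on unknown block counts, which the if/elif in the diff does not):

def pick_qwen_block(block_nums: int):
    # Stand-in for the dispatch added in Qwen_7b_Chat.load_model().
    if block_nums == 32:    # qwen-7b
        return "QWENBlock", [32, 2, 1, 0, 32, 128]
    if block_nums == 24:    # qwen-1.8b
        return "QWEN18Block", [24, 2, 1, 0, 16, 128]
    raise ValueError(f"unsupported Qwen variant with {block_nums} blocks")

print(pick_qwen_block(24))  # ('QWEN18Block', [24, 2, 1, 0, 16, 128])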
