
Commit 0abf711

support phi-2
1 parent bfe22b2 commit 0abf711

File tree

3 files changed: +1085 −4 lines changed


README.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@ llm-export is an LLM model export tool that can export LLM models to ONNX
 - [![Download][download-baichuan2-7b-chat-onnx]][release-baichuan2-7b-chat-onnx]
 - [![Download][download-llama2-7b-chat-onnx]][release-llama2-7b-chat-onnx]
 - [![Download][download-qwen-1.8b-chat-onnx]][release-qwen-1.8b-chat-onnx]
+- [![Download][download-phi-2-onnx]][release-phi-2-onnx]
 
 [download-chatglm-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm-6b-onnx/total
 [download-chatglm2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm2-6b-onnx/total
@@ -25,6 +26,7 @@ llm-export is an LLM model export tool that can export LLM models to ONNX
 [download-baichuan2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/baichuan2-7b-chat-onnx/total
 [download-llama2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/llama2-7b-chat-onnx/total
 [download-qwen-1.8b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-1.8b-onnx/total
+[download-phi-2-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/phi-2-onnx/total
 [release-chatglm-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm-6b-onnx
 [release-chatglm2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm2-6b-onnx
 [release-chatglm3-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm3-6b-onnx
@@ -33,6 +35,7 @@ llm-export is an LLM model export tool that can export LLM models to ONNX
 [release-baichuan2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/baichuan2-7b-chat-onnx
 [release-llama2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/llama2-7b-chat-onnx
 [release-qwen-1.8b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-1.8b-onnx
+[release-phi-2-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/phi-2-onnx
 
 ## Usage
 1. Clone this project locally
```

llm_export.py

Lines changed: 93 additions & 4 deletions
```diff
@@ -68,6 +68,8 @@ class LLM(torch.nn.Module):
 
     def __init__(self, args):
        super().__init__()
+        self.quant_bit = 4
+        self.asymmetric = True
        self.onnx_path = args.onnx_path
        self.mnn_path = args.mnn_path
        if not os.path.exists(self.onnx_path):
@@ -182,7 +184,7 @@ def export_lm(self):
        onnx_outs = ort_session.run(None, inputs)
        self.assert_equal(original_outs, onnx_outs)
        if self.export_mnn:
-            onnx2mnn(onnx_model, self.mnn_path)
+            onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric)
 
    def export_embed(self):
        model = self.embed
```
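
For context: `quant_bit` and `asymmetric` are forwarded to the MNN converter to control weight quantization. The sketch below is not llm-export or MNN code (MNN may quantize per channel or per block); it is a minimal NumPy illustration of what 4-bit asymmetric quantization means for a weight tensor, with hypothetical helper names:

```python
import numpy as np

def quant_asymmetric_4bit(w):
    # Map [w.min(), w.max()] onto the 16 levels [0, 15] using a zero point,
    # so an asymmetric weight range wastes no quantization levels.
    scale = (w.max() - w.min()) / 15
    zero_point = np.round(-w.min() / scale)
    q = np.clip(np.round(w / scale) + zero_point, 0, 15)
    return q.astype(np.uint8), scale, zero_point

def dequant(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

w = np.random.randn(8).astype(np.float32)
q, s, z = quant_asymmetric_4bit(w)
print(np.abs(w - dequant(q, s, z)).max())  # error is at most about one step (scale)
```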
```diff
@@ -311,12 +313,25 @@ def export_tokenizer(self):
                token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
                fp.write(f'{token_encode} {score} {type}\n')
            fp.close()
-        else:
+        elif hasattr(self.tokenizer, 'mergeable_ranks'):
            # tiktoken
            with open(file_path, "w", encoding="utf8") as fp:
                for k, v in self.tokenizer.mergeable_ranks.items():
                    line = base64.b64encode(k).decode("utf8") + "\n"
                    fp.write(line)
+        else:
+            # other
+            with open(file_path, "w", encoding="utf8") as fp:
+                vocab = self.tokenizer.get_vocab()
+                vocab_list = ['<unk>' for i in range(len(vocab))]
+                for k, v in vocab.items():
+                    k = k.replace('Ċ', '\n')
+                    k = k.replace('Ġ', ' ')
+                    vocab_list[int(v)] = k
+                for v in vocab_list:
+                    line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
+                    fp.write(line)
+
 
 # chatglm
 class GLMBlock(torch.nn.Module):
```
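
The new fallback branch handles byte-level BPE tokenizers (GPT-2 style, which phi-2 uses) that expose neither `sp_model` nor `mergeable_ranks`. A self-contained illustration of that serialization logic, using a hypothetical three-token vocabulary:

```python
import base64

# Hypothetical GPT-2-style byte-level BPE vocab: 'Ġ' encodes a leading space
# and 'Ċ' a newline, so both are mapped back before serialization.
vocab = {'Hello': 0, 'Ġworld': 1, 'Ċ': 2}
vocab_list = ['<unk>'] * len(vocab)
for token, token_id in vocab.items():
    token = token.replace('Ċ', '\n').replace('Ġ', ' ')
    vocab_list[token_id] = token
for token in vocab_list:
    # One base64-encoded token per line, ordered by token id, as tokenizer.txt expects.
    print(base64.b64encode(token.encode('utf-8')).decode('utf8'))
```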
```diff
@@ -405,6 +420,7 @@ def __init__(self, block, block_id, final_layernorm = None):
        self.block = block
        self.block_id = block_id
        self.final_layernorm = final_layernorm
+        self.hidden_size = 4096
 
    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
        theta = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
```
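
An aside on the change above: 4096 is the chatglm models' hidden dimension, and recording it as `self.hidden_size` presumably mirrors the `hidden_size` attribute that the new `PHI2Block` below receives as a constructor argument.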
```diff
@@ -447,7 +463,7 @@ def load_model(self, model_path: str):
        self.lm = Lm(self.lm_)
        self.blocks = [GLM2Block(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
        # some config for export
-        self.past_kv_shape = [28, 2, 0, 1, 2, 128]
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
        self.block_dynamic_axes = {
            "inputs_embeds" : { 0: "seq_len" },
            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
```
```diff
@@ -663,6 +679,78 @@ def get_position_ids(self) -> torch.Tensor:
            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
 
+class PHI2Block(torch.nn.Module):
+    def __init__(self, block, block_id, hidden_size):
+        super().__init__()
+        self.block = block
+        self.block_id = block_id
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states,
+                                             past_kv,
+                                             rotary_pos_emb=rotary_pos_emb,
+                                             causal_mask=attention_mask
+                                             )
+        if self.block_id == 31:
+            hidden_states = hidden_states[:, -1, :]
+        return hidden_states, presents
+
+class phi_2(LLM):
+    def __init__(self, args):
+        super().__init__(args)
+        self.model_name = 'phi-2'
+        self.asymmetric = False  # TODO: some precision bug when using asymmetric
+
+    def load_model(self, model_path: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+        transformer = model.transformer
+        self.lm_ = model.lm_head
+        self.embed_ = transformer.embd.wte
+        self.hidden_size = self.embed_.weight.shape[-1]
+        self.blocks_ = transformer.h
+        # self.final_layernorm_ = transformer.final_layernorm
+        # some wrapper
+        self.stop_id = self.tokenizer.eos_token_id
+        self.block_nums = len(self.blocks_)
+        self.embed = Embedding(self.embed_, self.embed_bf16)
+        self.lm = Lm(self.lm_)
+        self.blocks = [PHI2Block(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)]
+        # some config for export
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
+        self.block_dynamic_axes = {
+            "inputs_embeds" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 1: "history_len" }
+        }
+        self.model_dynamic_axes = {
+            "input_ids" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 2: "history_len" }
+        }
+
+    def build_prompt(self, query):
+        return f'Instruct: {query}\nOutput:'
+
+    def get_attention_mask(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.zeros([1, 1, 1, 1]).bool()
+        attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+        return attention_mask
+
+    def get_position_ids(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+
 if __name__ == '__main__':
    llm_models = {
        'chatglm-6b': Chatglm_6b,
```
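
Two details of `PHI2Block` worth unpacking. The rotary table: phi-2 rotates the first 32 channels of each head, so the code builds 16 frequencies per position and stacks cos/sin into a `[2, seq_len, 16]` tensor. The mask: `get_attention_mask` returns `True` where attention is forbidden (the strict upper triangle), and a dummy `False` mask during decode when there is only one query token. A standalone check mirroring the lines above:

```python
import torch

seq_len = 5
# Rotary table exactly as in PHI2Block.forward: 32 rotary channels -> 16 frequencies.
theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
position_ids = torch.arange(seq_len, dtype=torch.float32).reshape(-1, 1)
idx_theta = position_ids * theta
rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0)
print(rotary_pos_emb.shape)  # torch.Size([2, 5, 16])

# Causal mask as in phi_2.get_attention_mask: True marks masked-out positions.
mask = ~torch.tril(torch.ones([1, 1, seq_len, seq_len]).bool())
print(mask[0, 0])  # strict upper triangle True; diagonal and below False
```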
```diff
@@ -672,7 +760,8 @@ def get_position_ids(self) -> torch.Tensor:
        'Qwen-7B-Chat': Qwen_7b_Chat,
        'Qwen-1_8B-Chat': Qwen_7b_Chat,
        'Baichuan2-7B-Chat': Llama2_7b_Chat,
-        'Llama-2-7b-chat-ms': Llama2_7b_Chat
+        'Llama-2-7b-chat-ms': Llama2_7b_Chat,
+        'phi-2': phi_2
    }
    parser = argparse.ArgumentParser(description='LLMExporter', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True,
```
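
With the registry entry in place, phi-2 can presumably be exported like the other supported models, e.g. `python llm_export.py --path ./phi-2` against a local checkout of the Hugging Face checkpoint (`--path` is the only flag visible in this diff; the export-target flags are untouched by this commit).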
