
Commit 7a9d29f

support phi-2

1 parent bfe22b2

File tree

2 files changed: +1078 -3 lines changed


llm_export.py

Lines changed: 89 additions & 3 deletions
@@ -311,12 +311,25 @@ def export_tokenizer(self):
                 token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
                 fp.write(f'{token_encode} {score} {type}\n')
             fp.close()
-        else:
+        elif hasattr(self.tokenizer, 'mergeable_ranks'):
             # tiktoken
             with open(file_path, "w", encoding="utf8") as fp:
                 for k, v in self.tokenizer.mergeable_ranks.items():
                     line = base64.b64encode(k).decode("utf8") + "\n"
                     fp.write(line)
+        else:
+            # other
+            with open(file_path, "w", encoding="utf8") as fp:
+                vocab = self.tokenizer.get_vocab()
+                vocab_list = ['<unk>' for i in range(len(vocab))]
+                for k, v in vocab.items():
+                    k = k.replace('Ċ', '\n')
+                    k = k.replace('Ġ', ' ')
+                    vocab_list[int(v)] = k
+                for v in vocab_list:
+                    line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
+                    fp.write(line)
+
 
 # chatglm
 class GLMBlock(torch.nn.Module):
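
Note (not part of the commit): a minimal standalone sketch of what the new fallback branch writes, assuming a GPT-2 style byte-level vocab returned by get_vocab(); the toy vocab entries and the output filename below are invented for illustration.

# Sketch of the fallback vocab export: undo the byte-level markers,
# then emit one base64-encoded token per vocab id.
import base64

toy_vocab = {'Hello': 0, 'Ġworld': 1, 'Ċ': 2}   # hypothetical byte-level BPE entries

vocab_list = ['<unk>' for _ in range(len(toy_vocab))]
for k, v in toy_vocab.items():
    k = k.replace('Ċ', '\n')   # byte-level marker for newline
    k = k.replace('Ġ', ' ')    # byte-level marker for a leading space
    vocab_list[int(v)] = k

with open('tokenizer.txt', 'w', encoding='utf8') as fp:   # filename chosen for the sketch
    for token in vocab_list:
        fp.write(base64.b64encode(token.encode('utf-8')).decode('utf8') + '\n')

# Resulting lines: 'SGVsbG8=', 'IHdvcmxk', 'Cg==' -- one base64 line per token id.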
@@ -405,6 +418,7 @@ def __init__(self, block, block_id, final_layernorm = None):
         self.block = block
         self.block_id = block_id
         self.final_layernorm = final_layernorm
+        self.hidden_size = 4096
 
     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
         theta = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
@@ -447,7 +461,7 @@ def load_model(self, model_path: str):
         self.lm = Lm(self.lm_)
         self.blocks = [GLM2Block(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
         # some config for export
-        self.past_kv_shape = [28, 2, 0, 1, 2, 128]
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
         self.block_dynamic_axes = {
             "inputs_embeds" : { 0: "seq_len" },
             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
@@ -663,6 +677,77 @@ def get_position_ids(self) -> torch.Tensor:
             return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
         return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
 
+class PHI2Block(torch.nn.Module):
+    def __init__(self, block, block_id, hidden_size):
+        super().__init__()
+        self.block = block
+        self.block_id = block_id
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+        theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
+        position_ids = position_ids.float().reshape(-1, 1)
+        idx_theta = position_ids * theta
+        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
+        hidden_states = hidden_states.view(1, -1, self.hidden_size)
+        hidden_states, presents = self.block(hidden_states,
+                                             past_kv,
+                                             rotary_pos_emb=rotary_pos_emb,
+                                             causal_mask=attention_mask
+                                             )
+        if self.block_id == 31:
+            hidden_states = hidden_states[:, -1, :]
+        return hidden_states, presents
+
+class phi_2(LLM):
+    def __init__(self, args):
+        super().__init__(args)
+        self.model_name = 'phi-2'
+
+    def load_model(self, model_path: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+        transformer = model.transformer
+        self.lm_ = model.lm_head
+        self.embed_ = transformer.embd.wte
+        self.hidden_size = self.embed_.weight.shape[-1]
+        self.blocks_ = transformer.h
+        # self.final_layernorm_ = transformer.final_layernorm
+        # some wrapper
+        self.stop_id = self.tokenizer.eos_token_id
+        self.block_nums = len(self.blocks_)
+        self.embed = Embedding(self.embed_, self.embed_bf16)
+        self.lm = Lm(self.lm_)
+        self.blocks = [PHI2Block(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)]
+        # some config for export
+        self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
+        self.block_dynamic_axes = {
+            "inputs_embeds" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 1: "history_len" }
+        }
+        self.model_dynamic_axes = {
+            "input_ids" : { 0: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "position_ids" : { 0: "seq_len" },
+            "past_key_values" : { 2: "history_len" }
+        }
+
+    def build_prompt(self, query):
+        return f'Instruct: {query}\nOutput:'
+
+    def get_attention_mask(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.zeros([1, 1, 1, 1]).bool()
+        attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+        return attention_mask
+
+    def get_position_ids(self) -> torch.Tensor:
+        if self.token_len:
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+
 if __name__ == '__main__':
     llm_models = {
         'chatglm-6b': Chatglm_6b,
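
Note (not part of the commit): a standalone sketch of the two per-step inputs the new phi-2 path constructs, following the code above -- the cos/sin rotary table from PHI2Block.forward (32 rotary dims, base 10000) and the inverted-tril causal mask from get_attention_mask for the prefill step. The example seq_len is an arbitrary choice for illustration.

# Sketch mirroring the rotary table and causal mask built in this diff.
import torch

seq_len = 4  # arbitrary example length, not from the commit

# Rotary position embedding: arange(0, 32, 2) gives 16 frequencies.
theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)   # prefill positions
idx_theta = position_ids.float().reshape(-1, 1) * theta               # [seq_len, 16]
rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
print(rotary_pos_emb.shape)  # torch.Size([2, 4, 16]) -> [cos/sin, seq_len, 16]

# Causal mask: True marks positions that must NOT be attended (upper triangle),
# matching get_attention_mask() during prefill.
causal_mask = ~torch.tril(torch.ones([1, 1, seq_len, seq_len]).bool())
print(causal_mask[0, 0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])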
@@ -672,7 +757,8 @@ def get_position_ids(self) -> torch.Tensor:
         'Qwen-7B-Chat': Qwen_7b_Chat,
         'Qwen-1_8B-Chat': Qwen_7b_Chat,
         'Baichuan2-7B-Chat': Llama2_7b_Chat,
-        'Llama-2-7b-chat-ms': Llama2_7b_Chat
+        'Llama-2-7b-chat-ms': Llama2_7b_Chat,
+        'phi-2': phi_2
     }
     parser = argparse.ArgumentParser(description='LLMExporter', formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True,
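
Note (not part of the commit): with 'phi-2' registered in llm_models, the exporter can be pointed at a phi-2 checkpoint through the --path argument visible at the end of the diff. The path below is a placeholder; any export-related flags the script accepts beyond --path are not shown in this excerpt.

python llm_export.py --path /path/to/phi-2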
