From 4179fa2242809af9f10be435ac81ce7cb183d26a Mon Sep 17 00:00:00 2001 From: xiaobai52HZ <448449153@qq.com> Date: Thu, 12 Oct 2023 10:44:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BB=A3=E7=A0=81=E4=BB=A5?= =?UTF-8?q?=E4=BE=BF=E5=8F=AF=E4=BB=A5=E4=BB=8Eonnx=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E4=B8=BAtrt=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_export.py | 2 +- llm_models/Llama-2-7b-chat-ms/modeling_llama.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/llm_export.py b/llm_export.py index 4915ea2..9a7293f 100644 --- a/llm_export.py +++ b/llm_export.py @@ -558,7 +558,7 @@ def load_model(self, model_path: str): self.model_dynamic_axes = { "input_ids" : { 0: "seq_len" }, "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, + "position_ids" : { 1: "seq_len" }, "past_key_values" : { 4: "history_len" } } diff --git a/llm_models/Llama-2-7b-chat-ms/modeling_llama.py b/llm_models/Llama-2-7b-chat-ms/modeling_llama.py index 8c562c6..625e3ff 100644 --- a/llm_models/Llama-2-7b-chat-ms/modeling_llama.py +++ b/llm_models/Llama-2-7b-chat-ms/modeling_llama.py @@ -180,10 +180,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + #cos = torch.squeeze(cos) # [seq_len, dim] + #sin = torch.squeeze(sin) # [seq_len, dim] + #cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + #sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos.view(cos.shape[2], cos.shape[3]) # [seq_len, dim] + cos = cos[position_ids].view(position_ids.shape[0], 1, position_ids.shape[1], cos.shape[1]) # [bs, 1, seq_len, dim] + sin = sin.view(sin.shape[2], sin.shape[3]) # [seq_len, dim] + sin = sin[position_ids].view(position_ids.shape[0], 1, position_ids.shape[1], sin.shape[1]) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed