diff --git a/llm_export.py b/llm_export.py index 4915ea2..9a7293f 100644 --- a/llm_export.py +++ b/llm_export.py @@ -558,7 +558,7 @@ def load_model(self, model_path: str): self.model_dynamic_axes = { "input_ids" : { 0: "seq_len" }, "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, + "position_ids" : { 1: "seq_len" }, "past_key_values" : { 4: "history_len" } } diff --git a/llm_models/Llama-2-7b-chat-ms/modeling_llama.py b/llm_models/Llama-2-7b-chat-ms/modeling_llama.py index 8c562c6..625e3ff 100644 --- a/llm_models/Llama-2-7b-chat-ms/modeling_llama.py +++ b/llm_models/Llama-2-7b-chat-ms/modeling_llama.py @@ -180,10 +180,15 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + #cos = torch.squeeze(cos) # [seq_len, dim] + #sin = torch.squeeze(sin) # [seq_len, dim] + #cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + #sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos.view(cos.shape[2], cos.shape[3]) # [seq_len, dim] + cos = cos[position_ids].view(position_ids.shape[0], 1, position_ids.shape[1], cos.shape[1]) # [bs, 1, seq_len, dim] + sin = sin.view(sin.shape[2], sin.shape[3]) # [seq_len, dim] + sin = sin[position_ids].view(position_ids.shape[0], 1, position_ids.shape[1], sin.shape[1]) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed