diff --git a/llm_export.py b/llm_export.py
index e9ea7b3..f63f01c 100644
--- a/llm_export.py
+++ b/llm_export.py
@@ -58,7 +58,7 @@ def forward(self, input_ids):
         res = self.embed(input_ids)
         if self.bf16:
             res = res.float()
-        return res.view(-1, 1, self.embed_dim)
+        return res.view(1, -1, self.embed_dim)
 
 class Lm(torch.nn.Module):
     def __init__(self, lm):
@@ -67,8 +67,8 @@ def __init__(self, lm):
 
     def forward(self, hidden_states):
         m_logits = self.lm(hidden_states)
-        # token = torch.argmax(m_logits)
-        return m_logits
+        token = torch.argmax(m_logits)
+        return token
 
 class LLM(torch.nn.Module):
     '''
@@ -323,7 +323,7 @@ def export_embed(self):
     def export_block(self, block_id: int):
         self.seq_len = 3
         self.token_len = 0
-        inputs_embeds = torch.randn((self.seq_len, 1, self.hidden_size))
+        inputs_embeds = torch.randn((1, self.seq_len, self.hidden_size))
         attention_mask = self.get_attention_mask()
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape[1:])
@@ -841,14 +841,14 @@ def load_model(self):
         self.past_kv_shape = [24, 2, 1, 0, 16, 128]
         # some config for export
         self.block_dynamic_axes = {
-            "inputs_embeds" : { 0: "seq_len" },
-            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "inputs_embeds" : { 1: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "history_len"},
             "position_ids" : { 0: "seq_len" },
             "past_key_values" : { 2: "history_len" }
         }
         self.model_dynamic_axes = {
-            "input_ids" : { 0: "seq_len" },
-            "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+            "input_ids" : { 1: "seq_len" },
+            "attention_mask" : { 2: "seq_len", 3: "history_len" },
             "position_ids" : { 0: "seq_len" },
             "past_key_values" : { 3: "history_len" }
         }
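
To illustrate the batch-first change, here is a minimal, self-contained sketch of how a dynamic_axes dictionary like the updated block_dynamic_axes is consumed by torch.onnx.export. The DummyBlock module, the tensor sizes, and the output names below are illustrative assumptions, not the repository's actual export_block code; they only show why axis 1 of inputs_embeds is now the dynamic seq_len axis and why axis 3 of attention_mask is named history_len.

# Sketch only: DummyBlock, the sizes, and the output names are assumptions
# chosen to mirror the shapes in this patch, not the real transformer block.
import torch

class DummyBlock(torch.nn.Module):
    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values):
        # Placeholder math so every input appears in the traced graph;
        # a real block would run attention + MLP here.
        hidden_states = inputs_embeds + attention_mask.sum() * 0 + position_ids.sum() * 0
        presents = past_key_values * 1.0
        return hidden_states, presents

seq_len, hidden_size = 3, 2048                            # assumed sizes
inputs_embeds   = torch.randn(1, seq_len, hidden_size)    # batch-first: axis 1 is the dynamic one
attention_mask  = torch.zeros(1, 1, seq_len, seq_len)     # axis 2 -> seq_len, axis 3 named history_len
position_ids    = torch.arange(seq_len, dtype=torch.long)
past_key_values = torch.zeros(2, 1, 0, 16, 128)           # history_len = 0 at export time

block_dynamic_axes = {
    "inputs_embeds"   : { 1: "seq_len" },
    "attention_mask"  : { 2: "seq_len", 3: "history_len" },
    "position_ids"    : { 0: "seq_len" },
    "past_key_values" : { 2: "history_len" }
}

torch.onnx.export(
    DummyBlock(),
    (inputs_embeds, attention_mask, position_ids, past_key_values),
    "block_0.onnx",
    input_names=["inputs_embeds", "attention_mask", "position_ids", "past_key_values"],
    output_names=["hidden_states", "presents"],
    dynamic_axes=block_dynamic_axes,
    opset_version=15)

With axis 1 registered as seq_len, the exported block accepts any prompt length at inference time even though the tracing example used seq_len = 3; the same idea applies to model_dynamic_axes for the full-model export.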