Commit 137193f

[bugfix] Fix ONNX dynamic axes.
1 parent 5b462ce

File tree: 1 file changed

llmexport/llmexport.py (6 additions, 6 deletions)
@@ -1322,8 +1322,8 @@ def load_model(self, model_path):
         self.model_dynamic_axes = {
             "input_ids" : { 0: "seq_len" },
             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
-            "position_ids" : { 0: "seq_len" },
-            "past_key_values" : { 2: "history_len" }
+            "position_ids" : { 1: "seq_len" },
+            "past_key_values" : { 3: "history_len" }
         }
         self.llm_config = {
             'hidden_size' : self.hidden_size,
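
Why these values: get_position_ids() returns torch.arange(self.seq_len, ...).unsqueeze(0), a [1, seq_len] tensor, so the dynamic sequence dimension of position_ids is axis 1, not axis 0; likewise history_len sits on axis 3 of past_key_values, assuming a per-layer layout like [2, batch, num_heads, history_len, head_dim] (an assumption here, not stated in the diff). Below is a minimal sketch, not the repository's export code, of how a model_dynamic_axes mapping like this feeds torch.onnx.export; the toy module, tensor shapes, and file name are illustrative.

import torch

# Toy stand-in for the real model: it exists only so the exported graph
# has the four named inputs that the dynamic_axes mapping refers to.
class ToyLLM(torch.nn.Module):
    def forward(self, input_ids, attention_mask, position_ids, past_key_values):
        return (input_ids.float().sum() + attention_mask.sum()
                + position_ids.float().sum() + past_key_values.sum())

seq_len, history_len = 4, 3
dynamic_axes = {
    "input_ids"       : { 0: "seq_len" },
    "attention_mask"  : { 2: "seq_len", 3: "seq_len" },
    "position_ids"    : { 1: "seq_len" },      # [1, seq_len]: axis 1 varies
    "past_key_values" : { 3: "history_len" },  # layout assumed, see note above
}
torch.onnx.export(
    ToyLLM(),
    (torch.zeros(seq_len, dtype=torch.int),                # input_ids: [seq_len]
     torch.zeros(1, 1, seq_len, seq_len),                  # attention_mask: [1, 1, seq_len, seq_len]
     torch.arange(seq_len, dtype=torch.int).unsqueeze(0),  # position_ids: [1, seq_len]
     torch.zeros(2, 1, 2, history_len, 8)),                # past_key_values (shape assumed)
    "toy_llm.onnx",
    input_names=["input_ids", "attention_mask", "position_ids", "past_key_values"],
    dynamic_axes=dynamic_axes,
)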
@@ -1370,8 +1370,8 @@ def get_position_ids(self) -> torch.Tensor:
         if self.model_type == 'chatglm':
             return self.chatglm_position_ids()
         if self.token_len:
-            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
-        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.int)
+        return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)

     def chatglm_attention_mask(self):
         if self.token_len:
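
The remaining hunks swap torch.long for torch.int: in PyTorch, torch.int is an alias for int32 and torch.long for int64, so position_ids is now declared as INT32 rather than INT64 in the exported graph (presumably to match what the downstream runtime expects; the diff itself does not say why). A tiny illustrative check, not from the repository; the values 7 and 8 are arbitrary:

import torch

assert torch.int == torch.int32 and torch.long == torch.int64

decode_ids  = torch.tensor([[7]], dtype=torch.int)           # decode step: shape [1, 1]
prefill_ids = torch.arange(8, dtype=torch.int).unsqueeze(0)  # prefill: shape [1, 8]
print(decode_ids.dtype, prefill_ids.dtype)                   # torch.int32 torch.int32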
@@ -1385,8 +1385,8 @@ def chatglm_attention_mask(self):
     def chatglm_position_ids(self):
         if self.token_len:
             return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1])
-        position_ids_0 = torch.arange(self.seq_len, dtype=torch.long)
-        position_ids_1 = torch.zeros(self.seq_len, dtype=torch.long)
+        position_ids_0 = torch.arange(self.seq_len, dtype=torch.int)
+        position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int)
         position_ids_0[-1] = position_ids_0[-2]
         position_ids_1[-1] = 1
         position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
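
For the chatglm branch, here is a worked example of what the rewritten prefill path produces, using an arbitrary seq_len of 4: in ChatGLM's two-row position scheme, row 0 repeats the last real position and row 1 marks the final token.

import torch

seq_len = 4
position_ids_0 = torch.arange(seq_len, dtype=torch.int)  # [0, 1, 2, 3]
position_ids_1 = torch.zeros(seq_len, dtype=torch.int)   # [0, 0, 0, 0]
position_ids_0[-1] = position_ids_0[-2]                  # [0, 1, 2, 2]
position_ids_1[-1] = 1                                   # [0, 0, 0, 1]
position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
print(position_ids)  # tensor([[[0, 1, 2, 2], [0, 0, 0, 1]]], dtype=torch.int32)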
