@@ -1322,8 +1322,8 @@ def load_model(self, model_path):
         self.model_dynamic_axes = {
             "input_ids": {0: "seq_len"},
             "attention_mask": {2: "seq_len", 3: "seq_len"},
-            "position_ids": {0: "seq_len"},
-            "past_key_values": {2: "history_len"}
+            "position_ids": {1: "seq_len"},
+            "past_key_values": {3: "history_len"}
         }
         self.llm_config = {
             'hidden_size': self.hidden_size,
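
This first hunk retargets the dynamic ONNX axes: position_ids is fed as a [1, seq_len] tensor (see get_position_ids below), so its growing axis is 1, and the cache length of past_key_values sits on axis 3. Below is a minimal, self-contained sketch of how such a dict reaches torch.onnx.export; the Dummy module, tensor shapes, and file name are illustrative assumptions, not the exporter's real model:

import torch

class Dummy(torch.nn.Module):
    def forward(self, input_ids, attention_mask, position_ids, past_key_values):
        # Touch every input so tracing keeps all four in the exported graph.
        bias = (attention_mask.sum() + position_ids.sum() + past_key_values.sum()).float()
        return torch.zeros(1, input_ids.shape[0], 8) + bias

seq_len, history_len = 4, 2
torch.onnx.export(
    Dummy(),
    (torch.zeros(seq_len, dtype=torch.int),        # input_ids:       [seq_len]
     torch.zeros(1, 1, seq_len, seq_len),          # attention_mask:  [1, 1, seq_len, seq_len]
     torch.zeros(1, seq_len, dtype=torch.int),     # position_ids:    [1, seq_len]
     torch.zeros(2, 1, 2, history_len, 8)),        # past_key_values: history_len on axis 3
    "dummy_llm.onnx",
    input_names=["input_ids", "attention_mask", "position_ids", "past_key_values"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "seq_len"},
        "attention_mask": {2: "seq_len", 3: "seq_len"},
        "position_ids": {1: "seq_len"},            # axis 1 after this change
        "past_key_values": {3: "history_len"},     # axis 3 after this change
    },
)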
@@ -1370,8 +1370,8 @@ def get_position_ids(self) -> torch.Tensor:
         if self.model_type == 'chatglm':
             return self.chatglm_position_ids()
         if self.token_len:
-            return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
-        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+            return torch.tensor([[self.seq_len - 1]], dtype=torch.int)
+        return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
 
     def chatglm_attention_mask(self):
         if self.token_len:
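
The dtype switch from torch.long to torch.int makes the position ids int32 instead of int64 in the exported graph. Note that the decode branch returns a [1, 1] tensor and the prefill branch a [1, seq_len] tensor, consistent with axis 1 being the dynamic "seq_len" axis in the first hunk. A quick illustrative check of the two branches, with seq_len as an arbitrary example value rather than real exporter state:

import torch

seq_len = 5

# Decode step (token_len is set): a single [1, 1] position, now int32.
decode_ids = torch.tensor([[seq_len - 1]], dtype=torch.int)
assert decode_ids.shape == (1, 1) and decode_ids.dtype == torch.int32

# Prefill (token_len falsy): positions 0..seq_len-1 with shape [1, seq_len].
prefill_ids = torch.arange(seq_len, dtype=torch.int).unsqueeze(0)
assert prefill_ids.shape == (1, seq_len) and prefill_ids.dtype == torch.int32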
@@ -1385,8 +1385,8 @@ def chatglm_attention_mask(self):
     def chatglm_position_ids(self):
         if self.token_len:
             return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1])
-        position_ids_0 = torch.arange(self.seq_len, dtype=torch.long)
-        position_ids_1 = torch.zeros(self.seq_len, dtype=torch.long)
+        position_ids_0 = torch.arange(self.seq_len, dtype=torch.int)
+        position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int)
         position_ids_0[-1] = position_ids_0[-2]
         position_ids_1[-1] = 1
         position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
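
The same dtype change applies to ChatGLM's 2-D position encoding, which stacks regular positions and block positions into a [1, 2, seq_len] tensor. A worked example of the prefill branch above, mirroring its logic outside the class with an illustrative seq_len of 4:

import torch

seq_len = 4
position_ids_0 = torch.arange(seq_len, dtype=torch.int)   # [0, 1, 2, 3]
position_ids_1 = torch.zeros(seq_len, dtype=torch.int)    # [0, 0, 0, 0]
position_ids_0[-1] = position_ids_0[-2]                   # last slot repeats: [0, 1, 2, 2]
position_ids_1[-1] = 1                                    # block position starts: [0, 0, 0, 1]
position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
print(position_ids)  # tensor([[[0, 1, 2, 2], [0, 0, 0, 1]]], dtype=torch.int32)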