@@ -154,11 +154,16 @@ def tokenize_trajectory(
154154 return None
155155 messages_and_choices = history .messages_and_choices [: last_assistant_index + 1 ]
156156 messages = get_messages (messages_and_choices )
157+ tools = (
158+ [{"type" : "function" , "function" : tool } for tool in history .tools ]
159+ if history .tools is not None
160+ else None
161+ )
157162 chat = cast (
158163 str ,
159164 tokenizer .apply_chat_template (
160165 cast (list [dict ], messages ),
161- tools = history . tools , # type: ignore
166+ tools = tools ,
162167 continue_final_message = True ,
163168 tokenize = False ,
164169 ),
@@ -167,7 +172,7 @@ def tokenize_trajectory(
167172 list [int ],
168173 tokenizer .apply_chat_template (
169174 cast (list [dict ], messages ),
170- tools = history . tools , # type: ignore
175+ tools = tools ,
171176 continue_final_message = True ,
172177 ),
173178 )
@@ -193,7 +198,7 @@ def tokenize_trajectory(
193198 for message_or_choice in messages_and_choices
194199 ],
195200 ),
196- tools = history . tools , # type: ignore
201+ tools = tools ,
197202 continue_final_message = True ,
198203 ),
199204 )
@@ -204,6 +209,10 @@ def tokenize_trajectory(
204209 continue
205210 start = token_ids .index (sentinal_token_id )
206211 end = start + 1
212+ try :
213+ end_token_id = token_ids [end ]
214+ except IndexError :
215+ end_token_id = None
207216 if isinstance (message , dict ):
208217 content = message .get ("content" )
209218 assert isinstance (content , str )
@@ -247,6 +256,10 @@ def tokenize_trajectory(
247256 token_logprob .logprob for token_logprob in token_logprobs
248257 )
249258 assistant_mask [start :end ] = [1 ] * len (token_logprobs )
259+ if token_ids [start + len (token_logprobs ) - 1 ] == end_token_id :
260+ token_ids .pop (start + len (token_logprobs ))
261+ logprobs .pop (start + len (token_logprobs ))
262+ assistant_mask .pop (start + len (token_logprobs ))
250263 if image_processor :
251264 images : list [Image .Image ] = []
252265 for message in messages_and_choices :
0 commit comments