Skip to content

Commit 3a95f28

Browse files
committed
fix: Address some tokenization issues
1 parent 5e3c66c commit 3a95f28

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

src/art/preprocessing/tokenize.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,16 @@ def tokenize_trajectory(
154154
return None
155155
messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
156156
messages = get_messages(messages_and_choices)
157+
tools = (
158+
[{"type": "function", "function": tool} for tool in history.tools]
159+
if history.tools is not None
160+
else None
161+
)
157162
chat = cast(
158163
str,
159164
tokenizer.apply_chat_template(
160165
cast(list[dict], messages),
161-
tools=history.tools, # type: ignore
166+
tools=tools,
162167
continue_final_message=True,
163168
tokenize=False,
164169
),
@@ -167,7 +172,7 @@ def tokenize_trajectory(
167172
list[int],
168173
tokenizer.apply_chat_template(
169174
cast(list[dict], messages),
170-
tools=history.tools, # type: ignore
175+
tools=tools,
171176
continue_final_message=True,
172177
),
173178
)
@@ -193,7 +198,7 @@ def tokenize_trajectory(
193198
for message_or_choice in messages_and_choices
194199
],
195200
),
196-
tools=history.tools, # type: ignore
201+
tools=tools,
197202
continue_final_message=True,
198203
),
199204
)
@@ -204,6 +209,10 @@ def tokenize_trajectory(
204209
continue
205210
start = token_ids.index(sentinal_token_id)
206211
end = start + 1
212+
try:
213+
end_token_id = token_ids[end]
214+
except IndexError:
215+
end_token_id = None
207216
if isinstance(message, dict):
208217
content = message.get("content")
209218
assert isinstance(content, str)
@@ -247,6 +256,10 @@ def tokenize_trajectory(
247256
token_logprob.logprob for token_logprob in token_logprobs
248257
)
249258
assistant_mask[start:end] = [1] * len(token_logprobs)
259+
if token_ids[start + len(token_logprobs) - 1] == end_token_id:
260+
token_ids.pop(start + len(token_logprobs))
261+
logprobs.pop(start + len(token_logprobs))
262+
assistant_mask.pop(start + len(token_logprobs))
250263
if image_processor:
251264
images: list[Image.Image] = []
252265
for message in messages_and_choices:

0 commit comments

Comments
 (0)