|
21 | 21 | from ._provider import Provider |
22 | 22 | from ._tokens import tokens_log |
23 | 23 | from ._tools import Tool, basemodel_to_param_schema |
24 | | -from ._turn import Turn, normalize_turns |
| 24 | +from ._turn import Turn, normalize_turns, user_turn |
25 | 25 | from ._utils import MISSING, MISSING_TYPE, is_testing |
26 | 26 |
|
27 | 27 | if TYPE_CHECKING: |
@@ -366,22 +366,21 @@ def token_count( |
366 | 366 |
|
367 | 367 | encoding = tiktoken.encoding_for_model(self._model) |
368 | 368 |
|
369 | | - res: int = 0 |
370 | | - for arg in args: |
371 | | - if isinstance(arg, str): |
372 | | - res += len(encoding.encode(arg)) |
373 | | - elif isinstance(arg, ContentText): |
374 | | - res += len(encoding.encode(arg.text)) |
375 | | - elif isinstance(arg, ContentImage): |
376 | | - res += self._image_token_count(arg) |
377 | | - elif isinstance(arg, ContentToolResult): |
378 | | - res += len(encoding.encode(arg.get_final_value())) |
379 | | - else: |
380 | | - raise NotImplementedError( |
381 | | - f"Token counting for {type(arg)} not yet implemented." |
382 | | - ) |
| 369 | + turn = user_turn(*args) |
383 | 370 |
|
384 | | - return res |
| 371 | + # Count the tokens in image contents |
| 372 | + image_tokens = sum( |
| 373 | + self._image_token_count(x) |
| 374 | + for x in turn.contents |
| 375 | + if isinstance(x, ContentImage) |
| 376 | + ) |
| 377 | + |
| 378 | + # For other contents, get the token count from the actual message param |
| 379 | + other_contents = [x for x in turn.contents if not isinstance(x, ContentImage)] |
| 380 | + other_full = self._as_message_param([Turn("user", other_contents)]) |
| 381 | + other_tokens = len(encoding.encode(str(other_full))) |
| 382 | + |
| 383 | + return other_tokens + image_tokens |
385 | 384 |
|
386 | 385 | async def token_count_async( |
387 | 386 | self, |
|
0 commit comments