Skip to content

Commit 974194b

Browse files
authored
Merge pull request #130 from kcirred/print_size
[dpp] eliminated pad_token_id from print
2 parents a531fca + f62a7a6 commit 974194b

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

scripts/drive_paged_programs.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import os
55
import random
66
import time
7+
from itertools import dropwhile
78

89
import torch
910
from fms.models import get_model
@@ -46,7 +47,7 @@
4647
nargs="*",
4748
default=[],
4849
help="""
49-
The list of programs to run. This would take a list where each element would be one of program_id OR <program_id>:<min_batch>,<min_prompt_length>.
50+
The list of programs to run. This would take a list where each element would be one of program_id OR <program_id>:<min_batch>,<min_prompt_length>.
5051
If program_id is specified any prompt that would result in this program would be selected.
5152
If <program_id>:<min_batch>,<min_prompt_length> is specified, then with the given program_id, select a prompt that satisfies min_batch and min_prompt_length (if none exists, a message will be printed to warn the user)
5253
If this list is empty, each program will be run once with any prompt that would result in this program being selected.
@@ -565,8 +566,15 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
565566
aiu_tokens_generated = [
566567
t.item() for t in test_sentence[-max_new_tokens:]
567568
]
569+
tokens_prompt_without_pad = list(
570+
dropwhile(lambda x: x == tokenizer.pad_token_id, tokens_prompt)
571+
)
572+
prompt_length = len(
573+
[token_id for token_id in tokens_prompt_without_pad]
574+
)
575+
dprint(f"Prompt Length: {prompt_length}")
568576
dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:")
569-
dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
577+
dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt_without_pad)}")
570578
dprint(f"CPU tokens:\n{cpu_tokens_generated}")
571579
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
572580
dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}")

0 commit comments

Comments (0)