Commit f01056e

🐛 Fix default for sharegpt dataset path
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
1 parent: adda790

2 files changed: +16 -16 lines

aiu_fms_testing_utils/utils/__init__.py (14 additions, 14 deletions; whitespace-only cleanup, trailing whitespace removed)
@@ -37,17 +37,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-        
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-        
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str], 
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
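
For orientation (the hunk above is whitespace-only), the helper it touches streams the HTTP response to disk in 8 KiB chunks. A minimal standalone sketch of the same requests pattern, with a hypothetical URL and filename:

import requests

def download(url, filename):
    # stream=True avoids buffering the whole body in memory;
    # raise_for_status() turns HTTP error codes into exceptions.
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

download("https://example.com/sharegpt.json", "sharegpt.json")  # hypothetical URL and filename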
@@ -67,15 +67,15 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-        
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-        
+
     return filtered_dataset
-    
+
 
 
 def sample_sharegpt_requests(
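
The pruning rule above keeps a prompt only when its token length lands inside the inclusive [prompt_length_min, prompt_length_max] window. A tiny sketch of just that predicate, using the defaults from the signature in the previous hunk:

prompt_length_min, prompt_length_max = 32, 64

def keep(prompt_len):
    # Mirrors the hunk: prune sequences that are too short or too long;
    # both bounds are inclusive.
    return prompt_length_min <= prompt_len <= prompt_length_max

assert not keep(31) and keep(32) and keep(64) and not keep(65)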
@@ -96,15 +96,15 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-    
+
     return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int, 
-    tokenizer: BaseTokenizer, 
-    prompt_length_min: int = 32, 
-    prompt_length_max: int = 64, 
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
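
A hedged usage sketch for sample_sharegpt_requests, assuming its leading dataset_path parameter parallels sample_squad_v2_qa_requests above and that the tokenizer comes from the fms helper (neither is shown in this diff); all paths are hypothetical:

from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import sample_sharegpt_requests

tokenizer = tokenizers.get_tokenizer("/path/to/model")  # hypothetical model path
sampled = sample_sharegpt_requests(
    "/path/to/sharegpt.json",  # hypothetical dataset path
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=42,
)
# Per the diff, each entry is a (prompt_text, prompt_token_length) tuple.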
@@ -113,10 +113,10 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-    
-    
+
+
     ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
     return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-    
+
 
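
This hunk shows the SQuAD v2 loader's two paths: a dataset_path that resolves locally is loaded directly, while otherwise it serves as cache_dir for fetching "rajpurkar/squad_v2" from the HuggingFace Hub (the guarding condition sits just above the hunk and is not shown). A hedged sketch of calling it both ways, with hypothetical paths:

from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests

tokenizer = tokenizers.get_tokenizer("/path/to/model")  # hypothetical model path

# Local copy: resolves to load_dataset(dataset_path)['train'].
local = sample_squad_v2_qa_requests("/data/squad_v2", 8, tokenizer, 32, 64, seed=0)

# No local copy: dataset_path acts as cache_dir for the Hub download.
cached = sample_squad_v2_qa_requests("/tmp/hf_cache", 8, tokenizer, 32, 64, seed=0)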

scripts/generate_metrics.py (2 additions, 2 deletions; the substantive change: --sharegpt_path becomes required and its help text documents the fallback)
@@ -78,8 +78,8 @@
 parser.add_argument(
     "--sharegpt_path",
     type=str,
-    help="path to sharegpt data json",
-    required=False,
+    help="path to sharegpt data json. If it is not available, then use target path",
+    required=True,
 )
 parser.add_argument(
     "--output_dir",

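Because required=True, every run of the script must now pass --sharegpt_path explicitly; per the new help text, the target path is used when the JSON is not already available. An example invocation (paths hypothetical; any other arguments the script requires are elided):

python scripts/generate_metrics.py \
    --sharegpt_path /path/to/sharegpt.json \
    --output_dir /tmp/metrics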